morrySnow commented on code in PR #12159:
URL: https://github.com/apache/doris/pull/12159#discussion_r957976831
##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -54,86 +56,192 @@
*/
public class AggregateDisassemble extends OneRewriteRuleFactory {
- @Override
- public Rule build() {
- return logicalAggregate().when(agg ->
!agg.isDisassembled()).thenApply(ctx -> {
- LogicalAggregate<GroupPlan> aggregate = ctx.root;
- List<NamedExpression> originOutputExprs =
aggregate.getOutputExpressions();
- List<Expression> originGroupByExprs =
aggregate.getGroupByExpressions();
-
- // 1. generate a map from local aggregate output to global
aggregate expr substitution.
- // inputSubstitutionMap use for replacing expression in global
aggregate
- // replace rule is:
- // a: Expression is a group by key and is a slot reference.
e.g. group by k1
- // b. Expression is a group by key and is an expression.
e.g. group by k1 + 1
- // c. Expression is an aggregate function. e.g. sum(v1) in
select list
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | situation | origin expression | local output expression
| expression in global aggregate |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | a | Ref(k1)#1 | Ref(k1)#1
| Ref(k1)#1 |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | b | Ref(k1)#1 + 1 | A(Ref(k1)#1 + 1, key)#2
| Ref(key)#2 |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | c | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3
| AF(af#3) |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction,
#x: ExprId x
- // 2. collect local aggregate output expressions and local
aggregate group by expression list
- Map<Expression, Expression> inputSubstitutionMap =
Maps.newHashMap();
- List<Expression> localGroupByExprs =
aggregate.getGroupByExpressions();
- List<NamedExpression> localOutputExprs = Lists.newArrayList();
- for (Expression originGroupByExpr : originGroupByExprs) {
- if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+ private Map<Expression, Expression> inputSubstitutionMap =
Maps.newHashMap();
+ private List<NamedExpression> distinctAggFunctionParams = new
ArrayList<>();
+ private List<AggregateFunction> distinctAggFunctions = new ArrayList<>();
+ private List<NamedExpression> distinctOriginOutputExprs = new
ArrayList<>();
Review Comment:
```suggestion
private final Map<Expression, Expression> inputSubstitutionMap =
Maps.newHashMap();
private final List<NamedExpression> distinctAggFunctionParams =
Lists.newArrayList();
private final List<AggregateFunction> distinctAggFunctions =
Lists.newArrayList();
private final List<NamedExpression> distinctOriginOutputExprs =
Lists.newArrayList();
```
##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -54,86 +56,192 @@
*/
public class AggregateDisassemble extends OneRewriteRuleFactory {
- @Override
- public Rule build() {
- return logicalAggregate().when(agg ->
!agg.isDisassembled()).thenApply(ctx -> {
- LogicalAggregate<GroupPlan> aggregate = ctx.root;
- List<NamedExpression> originOutputExprs =
aggregate.getOutputExpressions();
- List<Expression> originGroupByExprs =
aggregate.getGroupByExpressions();
-
- // 1. generate a map from local aggregate output to global
aggregate expr substitution.
- // inputSubstitutionMap use for replacing expression in global
aggregate
- // replace rule is:
- // a: Expression is a group by key and is a slot reference.
e.g. group by k1
- // b. Expression is a group by key and is an expression.
e.g. group by k1 + 1
- // c. Expression is an aggregate function. e.g. sum(v1) in
select list
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | situation | origin expression | local output expression
| expression in global aggregate |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | a | Ref(k1)#1 | Ref(k1)#1
| Ref(k1)#1 |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | b | Ref(k1)#1 + 1 | A(Ref(k1)#1 + 1, key)#2
| Ref(key)#2 |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | c | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3
| AF(af#3) |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction,
#x: ExprId x
- // 2. collect local aggregate output expressions and local
aggregate group by expression list
- Map<Expression, Expression> inputSubstitutionMap =
Maps.newHashMap();
- List<Expression> localGroupByExprs =
aggregate.getGroupByExpressions();
- List<NamedExpression> localOutputExprs = Lists.newArrayList();
- for (Expression originGroupByExpr : originGroupByExprs) {
- if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+ private Map<Expression, Expression> inputSubstitutionMap =
Maps.newHashMap();
+ private List<NamedExpression> distinctAggFunctionParams = new
ArrayList<>();
+ private List<AggregateFunction> distinctAggFunctions = new ArrayList<>();
+ private List<NamedExpression> distinctOriginOutputExprs = new
ArrayList<>();
+
+ // only support distinct function with group by
+ // TODO: support distinct function without group by. (add second global
phase)
+ private LogicalAggregate secondDisassemble(LogicalAggregate<GroupPlan>
aggregate) {
+ // origin sql: select count(distinct a) from t1 group by b;
+ // global agg: select a from t1 group by b, a;
+ // second local agg: select count(distinct a) from t1 group by b;
+ // In order to get the second local agg from global agg:
+ // 1. the distinct expression needs to be removed from the output and
the group by of global agg
+ // 2. add distinct agg function back to output
+ List<NamedExpression> originFirstGlobalOutputExprs =
aggregate.getOutputExpressions();
+ List<Expression> originFirstGlobalGroupByExprs =
aggregate.getGroupByExpressions();
+ List<NamedExpression> secondLocalOutputExprs = new
ArrayList<>(aggregate.getOutputExpressions());
+ // add origin distinct function back
+ secondLocalOutputExprs.addAll(distinctOriginOutputExprs);
+ List<Expression> secondLocalGroupByExprs = new
ArrayList<>(aggregate.getGroupByExpressions());
+
+ List<Expression> distinctExprs = new ArrayList<>();
+ // remove distinct param exprs from secondLocalGroupByExprs and
secondLocalOutputExprs
+ for (NamedExpression expression : distinctAggFunctionParams) {
+
secondLocalGroupByExprs.remove(inputSubstitutionMap.get(expression));
+
secondLocalOutputExprs.remove(inputSubstitutionMap.get(expression));
+ distinctExprs.add(inputSubstitutionMap.get(expression));
+ }
+
+ Map<Expression, Expression> secondSubstitutionMap = Maps.newHashMap();
+
+ // replace the original slot reference with the latest one
+ Expression distinctAgg =
distinctAggFunctions.get(0).withChildren(distinctExprs);
+ secondSubstitutionMap.put(distinctAggFunctions.get(0), distinctAgg);
+ List<NamedExpression> secondLocalOutputNamedExprs =
secondLocalOutputExprs.stream()
+ .map(e -> ExpressionReplacer.INSTANCE.visit(e,
secondSubstitutionMap))
+ .map(NamedExpression.class::cast)
+ .collect(Collectors.toList());
+
+ // firstGlobalOutputExprs = originOutputExprs + originGroupByExprs
+ List<NamedExpression> firstGlobalOutputExprs = new
ArrayList<>(originFirstGlobalOutputExprs);
+ for (Expression originGroupByExpr : originFirstGlobalGroupByExprs) {
+ if (firstGlobalOutputExprs.contains(originGroupByExpr)) {
+ continue;
+ }
+ if (originGroupByExpr instanceof SlotReference) {
+ firstGlobalOutputExprs.add((SlotReference) originGroupByExpr);
+ } else {
+ Preconditions.checkState(false);
+ }
+ }
+
+ // generate new plan
+ LogicalAggregate globalAggregate = new LogicalAggregate<>(
+ originFirstGlobalGroupByExprs,
+ firstGlobalOutputExprs,
+ true,
+ aggregate.isNormalized(),
+ AggPhase.GLOBAL,
+ aggregate.child()
+ );
+
+ return new LogicalAggregate<>(
+ secondLocalGroupByExprs,
+ secondLocalOutputNamedExprs,
+ true,
+ aggregate.isNormalized(),
+ AggPhase.DISTINCT_LOCAL,
+ globalAggregate
+ );
+ }
+
+ private LogicalAggregate firstDisassemble(LogicalAggregate<GroupPlan>
aggregate) {
+ List<NamedExpression> originOutputExprs =
aggregate.getOutputExpressions();
+ List<Expression> originGroupByExprs =
aggregate.getGroupByExpressions();
+ // 1. generate a map from local aggregate output to global aggregate
expr substitution.
+ // inputSubstitutionMap use for replacing expression in global
aggregate
+ // replace rule is:
+ // a: Expression is a group by key and is a slot reference.
e.g. group by k1
+ // b. Expression is a group by key and is an expression. e.g.
group by k1 + 1
+ // c. Expression is an aggregate function. e.g. sum(v1) in
select list
+ //
+-----------+---------------------+-------------------------+--------------------------------+
+ // | situation | origin expression | local output expression |
expression in global aggregate |
+ //
+-----------+---------------------+-------------------------+--------------------------------+
+ // | a | Ref(k1)#1 | Ref(k1)#1 |
Ref(k1)#1 |
+ //
+-----------+---------------------+-------------------------+--------------------------------+
+ // | b | Ref(k1)#1 + 1 | A(Ref(k1)#1 + 1, key)#2 |
Ref(key)#2 |
+ //
+-----------+---------------------+-------------------------+--------------------------------+
+ // | c | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3 |
AF(af#3) |
+ //
+-----------+---------------------+-------------------------+--------------------------------+
+ // NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, #x:
ExprId x
+ // 2. collect local aggregate output expressions and local aggregate
group by expression list
+ List<Expression> localGroupByExprs = aggregate.getGroupByExpressions();
+ List<NamedExpression> localOutputExprs = Lists.newArrayList();
+ for (Expression originGroupByExpr : originGroupByExprs) {
+ if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+ continue;
+ }
+ if (originGroupByExpr instanceof SlotReference) {
+ inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
+ localOutputExprs.add((SlotReference) originGroupByExpr);
+ } else {
+ NamedExpression localOutputExpr = new Alias(originGroupByExpr,
originGroupByExpr.toSql());
+ inputSubstitutionMap.put(originGroupByExpr,
localOutputExpr.toSlot());
+ localOutputExprs.add(localOutputExpr);
+ }
+ }
+ for (NamedExpression originOutputExpr : originOutputExprs) {
+ List<AggregateFunction> aggregateFunctions
+ =
originOutputExpr.collect(AggregateFunction.class::isInstance);
+ for (AggregateFunction aggregateFunction : aggregateFunctions) {
+ if (inputSubstitutionMap.containsKey(aggregateFunction)) {
continue;
}
- if (originGroupByExpr instanceof SlotReference) {
- inputSubstitutionMap.put(originGroupByExpr,
originGroupByExpr);
- localOutputExprs.add((SlotReference) originGroupByExpr);
- } else {
- NamedExpression localOutputExpr = new
Alias(originGroupByExpr, originGroupByExpr.toSql());
- inputSubstitutionMap.put(originGroupByExpr,
localOutputExpr.toSlot());
- localOutputExprs.add(localOutputExpr);
- }
+ NamedExpression localOutputExpr = new Alias(aggregateFunction,
aggregateFunction.toSql());
+ Expression substitutionValue = aggregateFunction.withChildren(
+ Lists.newArrayList(localOutputExpr.toSlot()));
+ inputSubstitutionMap.put(aggregateFunction, substitutionValue);
+ localOutputExprs.add(localOutputExpr);
}
- for (NamedExpression originOutputExpr : originOutputExprs) {
- List<AggregateFunction> aggregateFunctions
- =
originOutputExpr.collect(AggregateFunction.class::isInstance);
- for (AggregateFunction aggregateFunction : aggregateFunctions)
{
- if (inputSubstitutionMap.containsKey(aggregateFunction)) {
- continue;
- }
- NamedExpression localOutputExpr = new
Alias(aggregateFunction, aggregateFunction.toSql());
- Expression substitutionValue =
aggregateFunction.withChildren(
- Lists.newArrayList(localOutputExpr.toSlot()));
- inputSubstitutionMap.put(aggregateFunction,
substitutionValue);
- localOutputExprs.add(localOutputExpr);
+ }
+
+ // 3. replace expression in globalOutputExprs and globalGroupByExprs
+ List<NamedExpression> globalOutputExprs =
aggregate.getOutputExpressions().stream()
+ .map(e -> ExpressionReplacer.INSTANCE.visit(e,
inputSubstitutionMap))
+ .map(NamedExpression.class::cast)
+ .collect(Collectors.toList());
+ List<Expression> globalGroupByExprs = localGroupByExprs.stream()
+ .map(e -> ExpressionReplacer.INSTANCE.visit(e,
inputSubstitutionMap)).collect(Collectors.toList());
+
+ // 4. generate new plan
+ LogicalAggregate localAggregate = new LogicalAggregate<>(
+ localGroupByExprs,
+ localOutputExprs,
+ true,
+ aggregate.isNormalized(),
+ AggPhase.LOCAL,
+ aggregate.child()
+ );
+ return new LogicalAggregate<>(
+ globalGroupByExprs,
+ globalOutputExprs,
+ true,
+ aggregate.isNormalized(),
+ AggPhase.GLOBAL,
+ localAggregate
+ );
+ }
+
+ private void
moveDistinctExprFromOutputToGroupBy(LogicalAggregate<GroupPlan> aggregate) {
+ // for example:
+ // select count(distinct a) from t1 group by b;
+ // => select a from t1 group by b, a;
+ List<NamedExpression> originOutputExprs =
aggregate.getOutputExpressions();
+ List<Expression> originGroupByExprs =
aggregate.getGroupByExpressions();
+
+ for (NamedExpression originOutputExpr : originOutputExprs) {
+ List<AggregateFunction> aggregateFunctions =
+
originOutputExpr.collect(AggregateFunction.class::isInstance);
+ for (AggregateFunction aggregateFunction : aggregateFunctions) {
+ if (aggregateFunction.isDistinct()) {
+ distinctAggFunctions.add(aggregateFunction);
+ distinctOriginOutputExprs.add(originOutputExpr);
}
}
+ }
+ if (!distinctAggFunctions.isEmpty()) {
+ for (Expression expr : distinctAggFunctions.get(0).children()) {
Review Comment:
Why only get(0)? Please add a comment explaining why only the first distinct aggregate function is used.
##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java:
##########
@@ -63,6 +63,16 @@ public class PlanTranslatorContext {
private final IdGenerator<PlanNodeId> nodeIdGenerator =
PlanNodeId.createGenerator();
+ private boolean hasDistinctAgg = false;
+
+ public boolean hasDistinctAgg() {
Review Comment:
We should not use a global flag to do that. Instead, we should put it into
PhysicalAggregate, or just return true if any AggregateFunction in the Aggregate
is distinct.
##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java:
##########
@@ -230,6 +246,13 @@ public PlanFragment
visitPhysicalAggregate(PhysicalAggregate<Plan> aggregate, Pl
case GLOBAL:
inputPlanFragment.updateDataPartition(mergePartition);
return inputPlanFragment;
+ case DISTINCT_LOCAL:
+ AggregationNode globalAggNode = (AggregationNode)
aggregationNode.getChild(0);
+ globalAggNode.unsetNeedsFinalize();
+ globalAggNode.setIntermediateTuple();
+ inputPlanFragment.updateDataPartition(mergePartition);
Review Comment:
When we do a distinct aggregation, the partition exprs are different from the
group by exprs. Where is this handled?
##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -54,86 +56,192 @@
*/
public class AggregateDisassemble extends OneRewriteRuleFactory {
- @Override
- public Rule build() {
- return logicalAggregate().when(agg ->
!agg.isDisassembled()).thenApply(ctx -> {
- LogicalAggregate<GroupPlan> aggregate = ctx.root;
- List<NamedExpression> originOutputExprs =
aggregate.getOutputExpressions();
- List<Expression> originGroupByExprs =
aggregate.getGroupByExpressions();
-
- // 1. generate a map from local aggregate output to global
aggregate expr substitution.
- // inputSubstitutionMap use for replacing expression in global
aggregate
- // replace rule is:
- // a: Expression is a group by key and is a slot reference.
e.g. group by k1
- // b. Expression is a group by key and is an expression.
e.g. group by k1 + 1
- // c. Expression is an aggregate function. e.g. sum(v1) in
select list
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | situation | origin expression | local output expression
| expression in global aggregate |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | a | Ref(k1)#1 | Ref(k1)#1
| Ref(k1)#1 |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | b | Ref(k1)#1 + 1 | A(Ref(k1)#1 + 1, key)#2
| Ref(key)#2 |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | c | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3
| AF(af#3) |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction,
#x: ExprId x
- // 2. collect local aggregate output expressions and local
aggregate group by expression list
- Map<Expression, Expression> inputSubstitutionMap =
Maps.newHashMap();
- List<Expression> localGroupByExprs =
aggregate.getGroupByExpressions();
- List<NamedExpression> localOutputExprs = Lists.newArrayList();
- for (Expression originGroupByExpr : originGroupByExprs) {
- if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+ private Map<Expression, Expression> inputSubstitutionMap =
Maps.newHashMap();
+ private List<NamedExpression> distinctAggFunctionParams = new
ArrayList<>();
+ private List<AggregateFunction> distinctAggFunctions = new ArrayList<>();
+ private List<NamedExpression> distinctOriginOutputExprs = new
ArrayList<>();
+
+ // only support distinct function with group by
+ // TODO: support distinct function without group by. (add second global
phase)
+ private LogicalAggregate secondDisassemble(LogicalAggregate<GroupPlan>
aggregate) {
+ // origin sql: select count(distinct a) from t1 group by b;
+ // global agg: select a from t1 group by b, a;
+ // second local agg: select count(distinct a) from t1 group by b;
+ // In order to get the second local agg from global agg:
+ // 1. the distinct expression needs to be removed from the output and
the group by of global agg
+ // 2. add distinct agg function back to output
+ List<NamedExpression> originFirstGlobalOutputExprs =
aggregate.getOutputExpressions();
Review Comment:
We'd better use the same phase names as AggPhase, i.e.:
- LOCAL
- GLOBAL
- DISTINCT_LOCAL
- DISTINCT_GLOBAL
##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -54,86 +56,192 @@
*/
public class AggregateDisassemble extends OneRewriteRuleFactory {
- @Override
- public Rule build() {
- return logicalAggregate().when(agg ->
!agg.isDisassembled()).thenApply(ctx -> {
- LogicalAggregate<GroupPlan> aggregate = ctx.root;
- List<NamedExpression> originOutputExprs =
aggregate.getOutputExpressions();
- List<Expression> originGroupByExprs =
aggregate.getGroupByExpressions();
-
- // 1. generate a map from local aggregate output to global
aggregate expr substitution.
- // inputSubstitutionMap use for replacing expression in global
aggregate
- // replace rule is:
- // a: Expression is a group by key and is a slot reference.
e.g. group by k1
- // b. Expression is a group by key and is an expression.
e.g. group by k1 + 1
- // c. Expression is an aggregate function. e.g. sum(v1) in
select list
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | situation | origin expression | local output expression
| expression in global aggregate |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | a | Ref(k1)#1 | Ref(k1)#1
| Ref(k1)#1 |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | b | Ref(k1)#1 + 1 | A(Ref(k1)#1 + 1, key)#2
| Ref(key)#2 |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | c | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3
| AF(af#3) |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction,
#x: ExprId x
- // 2. collect local aggregate output expressions and local
aggregate group by expression list
- Map<Expression, Expression> inputSubstitutionMap =
Maps.newHashMap();
- List<Expression> localGroupByExprs =
aggregate.getGroupByExpressions();
- List<NamedExpression> localOutputExprs = Lists.newArrayList();
- for (Expression originGroupByExpr : originGroupByExprs) {
- if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+ private Map<Expression, Expression> inputSubstitutionMap =
Maps.newHashMap();
+ private List<NamedExpression> distinctAggFunctionParams = new
ArrayList<>();
+ private List<AggregateFunction> distinctAggFunctions = new ArrayList<>();
+ private List<NamedExpression> distinctOriginOutputExprs = new
ArrayList<>();
+
+ // only support distinct function with group by
+ // TODO: support distinct function without group by. (add second global
phase)
+ private LogicalAggregate secondDisassemble(LogicalAggregate<GroupPlan>
aggregate) {
+ // origin sql: select count(distinct a) from t1 group by b;
+ // global agg: select a from t1 group by b, a;
+ // second local agg: select count(distinct a) from t1 group by b;
+ // In order to get the second local agg from global agg:
+ // 1. the distinct expression needs to be removed from the output and
the group by of global agg
+ // 2. add distinct agg function back to output
+ List<NamedExpression> originFirstGlobalOutputExprs =
aggregate.getOutputExpressions();
+ List<Expression> originFirstGlobalGroupByExprs =
aggregate.getGroupByExpressions();
+ List<NamedExpression> secondLocalOutputExprs = new
ArrayList<>(aggregate.getOutputExpressions());
+ // add origin distinct function back
+ secondLocalOutputExprs.addAll(distinctOriginOutputExprs);
+ List<Expression> secondLocalGroupByExprs = new
ArrayList<>(aggregate.getGroupByExpressions());
+
+ List<Expression> distinctExprs = new ArrayList<>();
+ // remove distinct param exprs from secondLocalGroupByExprs and
secondLocalOutputExprs
+ for (NamedExpression expression : distinctAggFunctionParams) {
+
secondLocalGroupByExprs.remove(inputSubstitutionMap.get(expression));
+
secondLocalOutputExprs.remove(inputSubstitutionMap.get(expression));
Review Comment:
secondLocalOutputExprs contains NamedExpressions such as Alias and SlotReference,
but I don't think the keys of inputSubstitutionMap include the same Alias.
##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -54,86 +56,192 @@
*/
public class AggregateDisassemble extends OneRewriteRuleFactory {
- @Override
- public Rule build() {
- return logicalAggregate().when(agg ->
!agg.isDisassembled()).thenApply(ctx -> {
- LogicalAggregate<GroupPlan> aggregate = ctx.root;
- List<NamedExpression> originOutputExprs =
aggregate.getOutputExpressions();
- List<Expression> originGroupByExprs =
aggregate.getGroupByExpressions();
-
- // 1. generate a map from local aggregate output to global
aggregate expr substitution.
- // inputSubstitutionMap use for replacing expression in global
aggregate
- // replace rule is:
- // a: Expression is a group by key and is a slot reference.
e.g. group by k1
- // b. Expression is a group by key and is an expression.
e.g. group by k1 + 1
- // c. Expression is an aggregate function. e.g. sum(v1) in
select list
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | situation | origin expression | local output expression
| expression in global aggregate |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | a | Ref(k1)#1 | Ref(k1)#1
| Ref(k1)#1 |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | b | Ref(k1)#1 + 1 | A(Ref(k1)#1 + 1, key)#2
| Ref(key)#2 |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | c | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3
| AF(af#3) |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction,
#x: ExprId x
- // 2. collect local aggregate output expressions and local
aggregate group by expression list
- Map<Expression, Expression> inputSubstitutionMap =
Maps.newHashMap();
- List<Expression> localGroupByExprs =
aggregate.getGroupByExpressions();
- List<NamedExpression> localOutputExprs = Lists.newArrayList();
- for (Expression originGroupByExpr : originGroupByExprs) {
- if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+ private Map<Expression, Expression> inputSubstitutionMap =
Maps.newHashMap();
+ private List<NamedExpression> distinctAggFunctionParams = new
ArrayList<>();
+ private List<AggregateFunction> distinctAggFunctions = new ArrayList<>();
+ private List<NamedExpression> distinctOriginOutputExprs = new
ArrayList<>();
+
+ // only support distinct function with group by
+ // TODO: support distinct function without group by. (add second global
phase)
+ private LogicalAggregate secondDisassemble(LogicalAggregate<GroupPlan>
aggregate) {
+ // origin sql: select count(distinct a) from t1 group by b;
+ // global agg: select a from t1 group by b, a;
+ // second local agg: select count(distinct a) from t1 group by b;
+ // In order to get the second local agg from global agg:
+ // 1. the distinct expression needs to be removed from the output and
the group by of global agg
+ // 2. add distinct agg function back to output
+ List<NamedExpression> originFirstGlobalOutputExprs =
aggregate.getOutputExpressions();
+ List<Expression> originFirstGlobalGroupByExprs =
aggregate.getGroupByExpressions();
+ List<NamedExpression> secondLocalOutputExprs = new
ArrayList<>(aggregate.getOutputExpressions());
+ // add origin distinct function back
+ secondLocalOutputExprs.addAll(distinctOriginOutputExprs);
Review Comment:
secondLocalOutputExprs is generated from the aggregate's output expressions, so I
think it already contains the original distinct function?
##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -54,86 +56,192 @@
*/
public class AggregateDisassemble extends OneRewriteRuleFactory {
Review Comment:
Please update this class's Javadoc comment to also describe the distinct case.
##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -54,86 +56,192 @@
*/
public class AggregateDisassemble extends OneRewriteRuleFactory {
- @Override
- public Rule build() {
- return logicalAggregate().when(agg ->
!agg.isDisassembled()).thenApply(ctx -> {
- LogicalAggregate<GroupPlan> aggregate = ctx.root;
- List<NamedExpression> originOutputExprs =
aggregate.getOutputExpressions();
- List<Expression> originGroupByExprs =
aggregate.getGroupByExpressions();
-
- // 1. generate a map from local aggregate output to global
aggregate expr substitution.
- // inputSubstitutionMap use for replacing expression in global
aggregate
- // replace rule is:
- // a: Expression is a group by key and is a slot reference.
e.g. group by k1
- // b. Expression is a group by key and is an expression.
e.g. group by k1 + 1
- // c. Expression is an aggregate function. e.g. sum(v1) in
select list
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | situation | origin expression | local output expression
| expression in global aggregate |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | a | Ref(k1)#1 | Ref(k1)#1
| Ref(k1)#1 |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | b | Ref(k1)#1 + 1 | A(Ref(k1)#1 + 1, key)#2
| Ref(key)#2 |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // | c | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3
| AF(af#3) |
- //
+-----------+---------------------+-------------------------+--------------------------------+
- // NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction,
#x: ExprId x
- // 2. collect local aggregate output expressions and local
aggregate group by expression list
- Map<Expression, Expression> inputSubstitutionMap =
Maps.newHashMap();
- List<Expression> localGroupByExprs =
aggregate.getGroupByExpressions();
- List<NamedExpression> localOutputExprs = Lists.newArrayList();
- for (Expression originGroupByExpr : originGroupByExprs) {
- if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+ private Map<Expression, Expression> inputSubstitutionMap =
Maps.newHashMap();
+ private List<NamedExpression> distinctAggFunctionParams = new
ArrayList<>();
+ private List<AggregateFunction> distinctAggFunctions = new ArrayList<>();
+ private List<NamedExpression> distinctOriginOutputExprs = new
ArrayList<>();
+
+ // only support distinct function with group by
+ // TODO: support distinct function without group by. (add second global
phase)
+ private LogicalAggregate secondDisassemble(LogicalAggregate<GroupPlan>
aggregate) {
+ // origin sql: select count(distinct a) from t1 group by b;
+ // global agg: select a from t1 group by b, a;
+ // second local agg: select count(distinct a) from t1 group by b;
+ // In order to get the second local agg from global agg:
+ // 1. the distinct expression needs to be removed from the output and
the group by of global agg
+ // 2. add distinct agg function back to output
+ List<NamedExpression> originFirstGlobalOutputExprs =
aggregate.getOutputExpressions();
+ List<Expression> originFirstGlobalGroupByExprs =
aggregate.getGroupByExpressions();
+ List<NamedExpression> secondLocalOutputExprs = new
ArrayList<>(aggregate.getOutputExpressions());
+ // add origin distinct function back
+ secondLocalOutputExprs.addAll(distinctOriginOutputExprs);
+ List<Expression> secondLocalGroupByExprs = new
ArrayList<>(aggregate.getGroupByExpressions());
+
+ List<Expression> distinctExprs = new ArrayList<>();
+ // remove distinct param exprs from secondLocalGroupByExprs and
secondLocalOutputExprs
+ for (NamedExpression expression : distinctAggFunctionParams) {
+
secondLocalGroupByExprs.remove(inputSubstitutionMap.get(expression));
+
secondLocalOutputExprs.remove(inputSubstitutionMap.get(expression));
+ distinctExprs.add(inputSubstitutionMap.get(expression));
+ }
+
+ Map<Expression, Expression> secondSubstitutionMap = Maps.newHashMap();
+
+ // replace the original slot reference with the latest one
+ Expression distinctAgg =
distinctAggFunctions.get(0).withChildren(distinctExprs);
+ secondSubstitutionMap.put(distinctAggFunctions.get(0), distinctAgg);
+ List<NamedExpression> secondLocalOutputNamedExprs =
secondLocalOutputExprs.stream()
+ .map(e -> ExpressionReplacer.INSTANCE.visit(e,
secondSubstitutionMap))
+ .map(NamedExpression.class::cast)
+ .collect(Collectors.toList());
+
+ // firstGlobalOutputExprs = originOutputExprs + originGroupByExprs
+ List<NamedExpression> firstGlobalOutputExprs = new
ArrayList<>(originFirstGlobalOutputExprs);
+ for (Expression originGroupByExpr : originFirstGlobalGroupByExprs) {
+ if (firstGlobalOutputExprs.contains(originGroupByExpr)) {
+ continue;
+ }
+ if (originGroupByExpr instanceof SlotReference) {
+ firstGlobalOutputExprs.add((SlotReference) originGroupByExpr);
+ } else {
+ Preconditions.checkState(false);
+ }
+ }
+
+ // generate new plan
+ LogicalAggregate globalAggregate = new LogicalAggregate<>(
+ originFirstGlobalGroupByExprs,
+ firstGlobalOutputExprs,
+ true,
+ aggregate.isNormalized(),
+ AggPhase.GLOBAL,
+ aggregate.child()
+ );
+
+ return new LogicalAggregate<>(
+ secondLocalGroupByExprs,
+ secondLocalOutputNamedExprs,
+ true,
+ aggregate.isNormalized(),
+ AggPhase.DISTINCT_LOCAL,
+ globalAggregate
+ );
+ }
+
+ private LogicalAggregate firstDisassemble(LogicalAggregate<GroupPlan>
aggregate) {
+ List<NamedExpression> originOutputExprs =
aggregate.getOutputExpressions();
+ List<Expression> originGroupByExprs =
aggregate.getGroupByExpressions();
+ // 1. generate a map from local aggregate output to global aggregate
expr substitution.
+ // inputSubstitutionMap use for replacing expression in global
aggregate
+ // replace rule is:
+ // a: Expression is a group by key and is a slot reference.
e.g. group by k1
+ // b. Expression is a group by key and is an expression. e.g.
group by k1 + 1
+ // c. Expression is an aggregate function. e.g. sum(v1) in
select list
+ //
+-----------+---------------------+-------------------------+--------------------------------+
+ // | situation | origin expression | local output expression |
expression in global aggregate |
+ //
+-----------+---------------------+-------------------------+--------------------------------+
+ // | a | Ref(k1)#1 | Ref(k1)#1 |
Ref(k1)#1 |
+ //
+-----------+---------------------+-------------------------+--------------------------------+
+ // | b | Ref(k1)#1 + 1 | A(Ref(k1)#1 + 1, key)#2 |
Ref(key)#2 |
+ //
+-----------+---------------------+-------------------------+--------------------------------+
+ // | c | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3 |
AF(af#3) |
+ //
+-----------+---------------------+-------------------------+--------------------------------+
+ // NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, #x:
ExprId x
+ // 2. collect local aggregate output expressions and local aggregate
group by expression list
+ List<Expression> localGroupByExprs = aggregate.getGroupByExpressions();
+ List<NamedExpression> localOutputExprs = Lists.newArrayList();
+ for (Expression originGroupByExpr : originGroupByExprs) {
+ if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+ continue;
+ }
+ if (originGroupByExpr instanceof SlotReference) {
+ inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
+ localOutputExprs.add((SlotReference) originGroupByExpr);
+ } else {
+ NamedExpression localOutputExpr = new Alias(originGroupByExpr,
originGroupByExpr.toSql());
+ inputSubstitutionMap.put(originGroupByExpr,
localOutputExpr.toSlot());
+ localOutputExprs.add(localOutputExpr);
+ }
+ }
+ for (NamedExpression originOutputExpr : originOutputExprs) {
+ List<AggregateFunction> aggregateFunctions
+ =
originOutputExpr.collect(AggregateFunction.class::isInstance);
+ for (AggregateFunction aggregateFunction : aggregateFunctions) {
+ if (inputSubstitutionMap.containsKey(aggregateFunction)) {
continue;
}
- if (originGroupByExpr instanceof SlotReference) {
- inputSubstitutionMap.put(originGroupByExpr,
originGroupByExpr);
- localOutputExprs.add((SlotReference) originGroupByExpr);
- } else {
- NamedExpression localOutputExpr = new
Alias(originGroupByExpr, originGroupByExpr.toSql());
- inputSubstitutionMap.put(originGroupByExpr,
localOutputExpr.toSlot());
- localOutputExprs.add(localOutputExpr);
- }
+ NamedExpression localOutputExpr = new Alias(aggregateFunction,
aggregateFunction.toSql());
+ Expression substitutionValue = aggregateFunction.withChildren(
+ Lists.newArrayList(localOutputExpr.toSlot()));
+ inputSubstitutionMap.put(aggregateFunction, substitutionValue);
+ localOutputExprs.add(localOutputExpr);
}
- for (NamedExpression originOutputExpr : originOutputExprs) {
- List<AggregateFunction> aggregateFunctions
- =
originOutputExpr.collect(AggregateFunction.class::isInstance);
- for (AggregateFunction aggregateFunction : aggregateFunctions)
{
- if (inputSubstitutionMap.containsKey(aggregateFunction)) {
- continue;
- }
- NamedExpression localOutputExpr = new
Alias(aggregateFunction, aggregateFunction.toSql());
- Expression substitutionValue =
aggregateFunction.withChildren(
- Lists.newArrayList(localOutputExpr.toSlot()));
- inputSubstitutionMap.put(aggregateFunction,
substitutionValue);
- localOutputExprs.add(localOutputExpr);
+ }
+
+ // 3. replace expression in globalOutputExprs and globalGroupByExprs
+ List<NamedExpression> globalOutputExprs =
aggregate.getOutputExpressions().stream()
+ .map(e -> ExpressionReplacer.INSTANCE.visit(e,
inputSubstitutionMap))
+ .map(NamedExpression.class::cast)
+ .collect(Collectors.toList());
+ List<Expression> globalGroupByExprs = localGroupByExprs.stream()
+ .map(e -> ExpressionReplacer.INSTANCE.visit(e,
inputSubstitutionMap)).collect(Collectors.toList());
+
+ // 4. generate new plan
+ LogicalAggregate localAggregate = new LogicalAggregate<>(
+ localGroupByExprs,
+ localOutputExprs,
+ true,
+ aggregate.isNormalized(),
+ AggPhase.LOCAL,
+ aggregate.child()
+ );
+ return new LogicalAggregate<>(
+ globalGroupByExprs,
+ globalOutputExprs,
+ true,
+ aggregate.isNormalized(),
+ AggPhase.GLOBAL,
+ localAggregate
+ );
+ }
+
+ private void
moveDistinctExprFromOutputToGroupBy(LogicalAggregate<GroupPlan> aggregate) {
+ // for example:
+ // select count(distinct a) from t1 group by b;
+ // => select a from t1 group by b, a;
+ List<NamedExpression> originOutputExprs =
aggregate.getOutputExpressions();
+ List<Expression> originGroupByExprs =
aggregate.getGroupByExpressions();
+
+ for (NamedExpression originOutputExpr : originOutputExprs) {
+ List<AggregateFunction> aggregateFunctions =
+
originOutputExpr.collect(AggregateFunction.class::isInstance);
+ for (AggregateFunction aggregateFunction : aggregateFunctions) {
+ if (aggregateFunction.isDistinct()) {
+ distinctAggFunctions.add(aggregateFunction);
+ distinctOriginOutputExprs.add(originOutputExpr);
}
}
+ }
+ if (!distinctAggFunctions.isEmpty()) {
+ for (Expression expr : distinctAggFunctions.get(0).children()) {
+ distinctAggFunctionParams.add((NamedExpression) expr);
Review Comment:
Why can we assume that `expr` is a `NamedExpression`? What about `count(distinct a + 1)`, where the argument is an arithmetic expression rather than a named slot?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]