This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new b32aac9195 [feature](Nereids)add normalize aggregate rule (#12013)
b32aac9195 is described below
commit b32aac919528f47a58247b4d05aa973a8643f885
Author: morrySnow <[email protected]>
AuthorDate: Wed Aug 24 18:30:18 2022 +0800
[feature](Nereids)add normalize aggregate rule (#12013)
---
.../org/apache/doris/nereids/rules/RuleType.java | 1 +
.../expression/rewrite/ExpressionRewrite.java | 4 +-
.../rules/rewrite/AggregateDisassemble.java | 3 +-
.../rules/rewrite/logical/NormalizeAggregate.java | 138 +++++++++++++++++
.../trees/plans/logical/LogicalAggregate.java | 27 +++-
.../trees/plans/logical/LogicalOlapScan.java | 2 +-
.../rewrite/logical/NormalizeAggregateTest.java | 168 +++++++++++++++++++++
7 files changed, 331 insertions(+), 12 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 7f4c22ba71..9dd73d04b4 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -45,6 +45,7 @@ public enum RuleType {
CHECK_ANALYSIS(RuleTypeClass.CHECK),
// rewrite rules
+ NORMALIZE_AGGREGATE(RuleTypeClass.REWRITE),
AGGREGATE_DISASSEMBLE(RuleTypeClass.REWRITE),
COLUMN_PRUNE_PROJECTION(RuleTypeClass.REWRITE),
ELIMINATE_ALIAS_NODE(RuleTypeClass.REWRITE),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java
index b285eaa2fa..96e61abc10 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java
@@ -96,8 +96,8 @@ public class ExpressionRewrite implements RewriteRuleFactory {
if (outputExpressions.containsAll(newOutputExpressions)) {
return agg;
}
- return new LogicalAggregate<>(newGroupByExprs,
newOutputExpressions, agg.isDisassembled(),
- agg.getAggPhase(), agg.child());
+ return new LogicalAggregate<>(newGroupByExprs,
newOutputExpressions,
+ agg.isDisassembled(), agg.isNormalized(),
agg.getAggPhase(), agg.child());
}).toRule(RuleType.REWRITE_AGG_EXPRESSION);
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
index 8a0363e103..1e8da9a14f 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
@@ -51,7 +51,6 @@ import java.util.stream.Collectors;
* TODO:
* 1. use different class represent different phase aggregate
* 2. if instance count is 1, shouldn't disassemble the agg plan
- * 3. we need another rule to removing duplicated expressions in group by
expression list
*/
public class AggregateDisassemble extends OneRewriteRuleFactory {
@@ -123,6 +122,7 @@ public class AggregateDisassemble extends
OneRewriteRuleFactory {
localGroupByExprs,
localOutputExprs,
true,
+ aggregate.isNormalized(),
AggPhase.LOCAL,
aggregate.child()
);
@@ -130,6 +130,7 @@ public class AggregateDisassemble extends
OneRewriteRuleFactory {
globalGroupByExprs,
globalOutputExprs,
true,
+ aggregate.isNormalized(),
AggPhase.GLOBAL,
localAggregate
);
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
new file mode 100644
index 0000000000..5aa70a0af3
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
@@ -0,0 +1,138 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionReplacer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * normalize aggregate's group keys to SlotReference and generate a
LogicalProject top on LogicalAggregate
+ * to hold to order of aggregate output, since aggregate output's order could
change when we do translate.
+ *
+ * Apply this rule could simplify the processing of enforce and translate.
+ *
+ * Original Plan:
+ * Aggregate(
+ * keys:[k1#1, K2#2 + 1],
+ * outputs:[k1#1, Alias(K2# + 1)#4, Alias(k1#1 + 1)#5, Alias(SUM(v1#3))#6,
+ * Alias(SUM(v1#3 + 1))#7, Alias(SUM(v1#3) + 1)#8])
+ *
+ * After rule:
+ * Project(k1#1, Alias(SR#9)#4, Alias(k1#1 + 1)#5, Alias(SR#10))#6,
Alias(SR#11))#7, Alias(SR#10 + 1)#8)
+ * +-- Aggregate(keys:[k1#1, SR#9], outputs:[k1#1, SR#9, Alias(SUM(v1#3))#10,
Alias(SUM(v1#3 + 1))#11])
+ * +-- Project(k1#1, Alias(K2#2 + 1)#9, v1#3)
+ *
+ * More example could get from UT {@link NormalizeAggregateTest}
+ */
+public class NormalizeAggregate extends OneRewriteRuleFactory {
+ @Override
+ public Rule build() {
+ return logicalAggregate().when(aggregate ->
!aggregate.isNormalized()).then(aggregate -> {
+ // substitution map used to substitute expression in aggregate's
output to use it as top projections
+ Map<Expression, Expression> substitutionMap = Maps.newHashMap();
+ List<Expression> keys = aggregate.getGroupByExpressions();
+ List<NamedExpression> newOutputs = Lists.newArrayList();
+
+ // keys
+ Map<Boolean, List<Expression>> partitionedKeys = keys.stream()
+
.collect(Collectors.groupingBy(SlotReference.class::isInstance));
+ List<Expression> newKeys = Lists.newArrayList();
+ List<NamedExpression> bottomProjections = Lists.newArrayList();
+ if (partitionedKeys.containsKey(false)) {
+ // process non-SlotReference keys
+ newKeys.addAll(partitionedKeys.get(false).stream()
+ .map(e -> new Alias(e, e.toSql()))
+ .peek(a -> substitutionMap.put(a.child(), a.toSlot()))
+ .peek(bottomProjections::add)
+ .map(Alias::toSlot)
+ .collect(Collectors.toList()));
+ }
+ if (partitionedKeys.containsKey(true)) {
+ // process SlotReference keys
+ partitionedKeys.get(true).stream()
+ .map(SlotReference.class::cast)
+ .peek(s -> substitutionMap.put(s, s))
+ .peek(bottomProjections::add)
+ .forEach(newKeys::add);
+ }
+ // add all necessary key to output
+ substitutionMap.entrySet().stream()
+ .filter(kv -> aggregate.getOutputExpressions().stream()
+ .anyMatch(e -> e.anyMatch(kv.getKey()::equals)))
+ .map(Entry::getValue)
+ .map(NamedExpression.class::cast)
+ .forEach(newOutputs::add);
+
+ // if we generate bottom, we need to generate to project too.
+ // output
+ List<NamedExpression> outputs = aggregate.getOutputExpressions();
+ Map<Boolean, List<NamedExpression>> partitionedOutputs =
outputs.stream()
+ .collect(Collectors.groupingBy(e ->
e.anyMatch(AggregateFunction.class::isInstance)));
+ if (partitionedOutputs.containsKey(true)) {
+ // process expressions that contain aggregate function
+ Set<AggregateFunction> aggregateFunctions =
partitionedOutputs.get(true).stream()
+ .flatMap(e ->
e.<List<AggregateFunction>>collect(AggregateFunction.class::isInstance).stream())
+ .collect(Collectors.toSet());
+ newOutputs.addAll(aggregateFunctions.stream()
+ .map(f -> new Alias(f, f.toSql()))
+ .peek(a -> substitutionMap.put(a.child(), a.toSlot()))
+ .collect(Collectors.toList()));
+ // add slot references in aggregate function to bottom
projections
+ bottomProjections.addAll(aggregateFunctions.stream()
+ .flatMap(f ->
f.<List<SlotReference>>collect(SlotReference.class::isInstance).stream())
+ .map(SlotReference.class::cast)
+ .collect(Collectors.toSet()));
+ }
+
+
+ // assemble
+ LogicalPlan root = aggregate.child();
+ if (partitionedKeys.containsKey(false)) {
+ root = new LogicalProject<>(bottomProjections, root);
+ }
+ root = new LogicalAggregate<>(newKeys, newOutputs,
aggregate.isDisassembled(),
+ true, aggregate.getAggPhase(), root);
+ List<NamedExpression> projections = outputs.stream()
+ .map(e -> ExpressionReplacer.INSTANCE.visit(e,
substitutionMap))
+ .map(NamedExpression.class::cast)
+ .collect(Collectors.toList());
+ root = new LogicalProject<>(projections, root);
+
+ return root;
+ }).toRule(RuleType.NORMALIZE_AGGREGATE);
+ }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
index 06ec298851..019c090a96 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
@@ -52,6 +52,7 @@ import java.util.Optional;
public class LogicalAggregate<CHILD_TYPE extends Plan> extends
LogicalUnary<CHILD_TYPE> implements Aggregate {
private final boolean disassembled;
+ private final boolean normalized;
private final List<Expression> groupByExpressions;
private final List<NamedExpression> outputExpressions;
private final AggPhase aggPhase;
@@ -63,16 +64,18 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
List<Expression> groupByExpressions,
List<NamedExpression> outputExpressions,
CHILD_TYPE child) {
- this(groupByExpressions, outputExpressions, false, AggPhase.GLOBAL,
child);
+ this(groupByExpressions, outputExpressions, false, false,
AggPhase.GLOBAL, child);
}
public LogicalAggregate(
List<Expression> groupByExpressions,
List<NamedExpression> outputExpressions,
boolean disassembled,
+ boolean normalized,
AggPhase aggPhase,
CHILD_TYPE child) {
- this(groupByExpressions, outputExpressions, disassembled, aggPhase,
Optional.empty(), Optional.empty(), child);
+ this(groupByExpressions, outputExpressions, disassembled, normalized,
+ aggPhase, Optional.empty(), Optional.empty(), child);
}
/**
@@ -82,6 +85,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
List<Expression> groupByExpressions,
List<NamedExpression> outputExpressions,
boolean disassembled,
+ boolean normalized,
AggPhase aggPhase,
Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties,
@@ -90,6 +94,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
this.groupByExpressions = groupByExpressions;
this.outputExpressions = outputExpressions;
this.disassembled = disassembled;
+ this.normalized = normalized;
this.aggPhase = aggPhase;
}
@@ -136,6 +141,10 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
return disassembled;
}
+ public boolean isNormalized() {
+ return normalized;
+ }
+
/**
* Determine the equality with another plan
*/
@@ -160,23 +169,25 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
@Override
public LogicalAggregate<Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 1);
- return new LogicalAggregate(groupByExpressions, outputExpressions,
disassembled, aggPhase, children.get(0));
+ return new LogicalAggregate<>(groupByExpressions, outputExpressions,
+ disassembled, normalized, aggPhase, children.get(0));
}
@Override
public LogicalAggregate<Plan>
withGroupExpression(Optional<GroupExpression> groupExpression) {
- return new LogicalAggregate(groupByExpressions, outputExpressions,
disassembled, aggPhase, groupExpression,
- Optional.of(logicalProperties), children.get(0));
+ return new LogicalAggregate<>(groupByExpressions, outputExpressions,
+ disassembled, normalized, aggPhase, groupExpression,
Optional.of(logicalProperties), children.get(0));
}
@Override
public LogicalAggregate<Plan>
withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
- return new LogicalAggregate(groupByExpressions, outputExpressions,
disassembled, aggPhase, Optional.empty(),
- logicalProperties, children.get(0));
+ return new LogicalAggregate<>(groupByExpressions, outputExpressions,
+ disassembled, normalized, aggPhase, Optional.empty(),
logicalProperties, children.get(0));
}
public LogicalAggregate<Plan> withGroupByAndOutput(List<Expression>
groupByExprList,
List<NamedExpression>
outputExpressionList) {
- return new LogicalAggregate(groupByExprList, outputExpressionList,
disassembled, aggPhase, child());
+ return new LogicalAggregate<>(groupByExprList, outputExpressionList,
+ disassembled, normalized, aggPhase, child());
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java
index c071a84ea1..a2d5372327 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java
@@ -68,7 +68,7 @@ public class LogicalOlapScan extends LogicalRelation {
return "ScanOlapTable ("
+ qualifiedName()
+ ", output: "
- +
computeOutput().stream().map(Objects::toString).collect(Collectors.joining(",
", "[", "]"))
+ +
getOutput().stream().map(Objects::toString).collect(Collectors.joining(", ",
"[", "]"))
+ ")";
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java
new file mode 100644
index 0000000000..fd44a0d628
--- /dev/null
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java
@@ -0,0 +1,168 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.trees.expressions.Add;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Multiply;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.Sum;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.util.FieldChecker;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.nereids.util.PlanConstructor;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInstance;
+
+import java.util.List;
+
+/**
+ * Unit tests for {@link NormalizeAggregate}. Each test builds a LogicalAggregate over a scan of the
+ * "student" table, applies the rule top-down through PlanChecker and pattern-matches the rewritten plan.
+ */
+@TestInstance(TestInstance.Lifecycle.PER_CLASS)
+public class NormalizeAggregateTest implements PatternMatchSupported {
+ // scan of the "student" test table shared by all tests; created once in beforeAll()
+ private Plan rStudent;
+
+ @BeforeAll
+ public final void beforeAll() {
+ rStudent = new LogicalOlapScan(PlanConstructor.student,
ImmutableList.of("student"));
+ }
+
+ /**
+ * original plan:
+ * LogicalAggregate (phase: [GLOBAL], output: [name#2, sum(id#0) AS
`sum`#4], groupBy: [name#2])
+ * +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2,
age#3])
+ *
+ * after rewrite:
+ * LogicalProject (name#2, sum(id)#5 AS `sum`#4)
+ * +--LogicalAggregate (phase: [GLOBAL], output: [name#2, sum(id#0) AS
`sum(id)`#5], groupBy: [name#2])
+ * +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2,
age#3])
+ */
+ @Test
+ public void testSimpleKeyWithSimpleAggregateFunction() {
+ NamedExpression key = rStudent.getOutput().get(2).toSlot();
+ NamedExpression aggregateFunction = new Alias(new
Sum(rStudent.getOutput().get(0).toSlot()), "sum");
+ List<Expression> groupExpressionList = Lists.newArrayList(key);
+ List<NamedExpression> outputExpressionList = Lists.newArrayList(key,
aggregateFunction);
+ Plan root = new LogicalAggregate<>(groupExpressionList,
outputExpressionList, rStudent);
+
+ // expected: Project over Aggregate directly over Scan (no bottom project, the only key is a slot);
+ // the top project must re-alias the SUM output slot with the original alias' ExprId
+ PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+ .applyTopDown(new NormalizeAggregate())
+ .matchesFromRoot(
+ logicalProject(
+ logicalAggregate(
+ logicalOlapScan()
+
).when(FieldChecker.check("groupByExpressions", ImmutableList.of(key)))
+ .when(aggregate ->
aggregate.getOutputExpressions().get(0).equals(key))
+ .when(aggregate ->
aggregate.getOutputExpressions().get(1).child(0).equals(aggregateFunction.child(0)))
+ .when(FieldChecker.check("normalized",
true))
+ ).when(project ->
project.getProjects().get(0).equals(key))
+ .when(project -> project.getProjects().get(1)
instanceof Alias)
+ .when(project -> ((Alias)
(project.getProjects().get(1))).getExprId().equals(aggregateFunction.getExprId()))
+ .when(project ->
project.getProjects().get(1).child(0) instanceof SlotReference)
+ );
+ }
+
+ /**
+ * original plan:
+ * LogicalAggregate (phase: [GLOBAL], output: [(sum((id#0 * 1)) + 2) AS
`(sum((id * 1)) + 2)`#4], groupBy: [name#2])
+ * +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2,
age#3])
+ *
+ * after rewrite:
+ * LogicalProject ((sum((id * 1))#5 + 2) AS `(sum((id * 1)) + 2)`#4)
+ * +--LogicalAggregate (phase: [GLOBAL], output: [sum((id#0 * 1)) AS
`sum((id * 1))`#5], groupBy: [name#2])
+ * +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2,
age#3])
+ */
+ @Test
+ public void testComplexFuncWithComplexOutputOfFunc() {
+ NamedExpression key = rStudent.getOutput().get(2).toSlot();
+ List<Expression> groupExpressionList = Lists.newArrayList(key);
+ Expression aggregateFunction = new Sum(new
Multiply(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(1)));
+ Expression complexOutput = new Add(aggregateFunction, new
IntegerLiteral(2));
+ Alias output = new Alias(complexOutput, complexOutput.toSql());
+ List<NamedExpression> outputExpressionList =
Lists.newArrayList(output);
+ Plan root = new LogicalAggregate<>(groupExpressionList,
outputExpressionList, rStudent);
+
+ // expected: the unreferenced key is not in the aggregate output, and the "+ 2" wrapped around the
+ // SUM is evaluated by the top project on the SUM's output slot
+ PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+ .applyTopDown(new NormalizeAggregate())
+ .matchesFromRoot(
+ logicalProject(
+ logicalAggregate(
+ logicalOlapScan()
+
).when(FieldChecker.check("groupByExpressions", ImmutableList.of(key)))
+ .when(aggregate ->
aggregate.getOutputExpressions().size() == 1)
+ .when(aggregate ->
aggregate.getOutputExpressions().get(0).child(0).equals(aggregateFunction))
+ ).when(project -> project.getProjects().size() == 1)
+ .when(project -> project.getProjects().get(0)
instanceof Alias)
+ .when(project ->
project.getProjects().get(0).getExprId().equals(output.getExprId()))
+ .when(project ->
project.getProjects().get(0).child(0) instanceof Add)
+ .when(project ->
project.getProjects().get(0).child(0).child(0) instanceof SlotReference)
+ .when(project ->
project.getProjects().get(0).child(0).child(1).equals(new IntegerLiteral(2)))
+ );
+ }
+
+
+ /**
+ * original plan:
+ * LogicalAggregate (phase: [GLOBAL], output: [((gender#1 + 1) + 2) AS
`((gender + 1) + 2)`#4], groupBy: [(gender#1 + 1)])
+ * +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2,
age#3])
+ *
+ * after rewrite:
+ * LogicalProject (((gender + 1)#5 + 2) AS `((gender + 1) + 2)`#4)
+ * +--LogicalAggregate (phase: [GLOBAL], output: [(gender + 1)#5],
groupBy: [(gender + 1)#5])
+ * +--LogicalProject ((gender#1 + 1) AS `(gender + 1)`#5)
+ * +--ScanOlapTable (student.student, output: [id#0, gender#1,
name#2, age#3])
+ */
+ @Test
+ public void testComplexKeyWithComplexOutputOfKey() {
+ Expression key = new Add(rStudent.getOutput().get(1).toSlot(), new
IntegerLiteral(1));
+ Expression complexKeyOutput = new Add(key, new IntegerLiteral(2));
+ NamedExpression keyOutput = new Alias(complexKeyOutput,
complexKeyOutput.toSql());
+ List<Expression> groupExpressionList = Lists.newArrayList(key);
+ List<NamedExpression> outputExpressionList =
Lists.newArrayList(keyOutput);
+ Plan root = new LogicalAggregate<>(groupExpressionList,
outputExpressionList, rStudent);
+
+ // expected: the complex key (gender + 1) is computed by a bottom project, the aggregate groups by and
+ // outputs that project's slot, and the top project evaluates "+ 2" on the slot
+ PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+ .applyTopDown(new NormalizeAggregate())
+ .matchesFromRoot(
+ logicalProject(
+ logicalAggregate(
+ logicalProject(
+ logicalOlapScan()
+ ).when(project ->
project.getProjects().size() == 1)
+ .when(project ->
project.getProjects().get(0) instanceof Alias)
+ .when(project ->
project.getProjects().get(0).child(0).equals(key))
+ ).when(aggregate ->
aggregate.getGroupByExpressions().get(0) instanceof SlotReference)
+ .when(aggregate ->
aggregate.getOutputExpressions().get(0) instanceof SlotReference)
+ .when(aggregate ->
aggregate.getGroupByExpressions().equals(aggregate.getOutputExpressions()))
+ ).when(project ->
project.getProjects().get(0).getExprId().equals(keyOutput.getExprId()))
+ .when(project ->
project.getProjects().get(0).child(0) instanceof Add)
+ .when(project ->
project.getProjects().get(0).child(0).child(0) instanceof SlotReference)
+ .when(project ->
project.getProjects().get(0).child(0).child(1).equals(new IntegerLiteral(2)))
+
+ );
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]