This is an automated email from the ASF dual-hosted git repository. mbudiu pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push: new 022d878a73 [CALCITE-6261] AssertionError with field pruning & duplicate agg calls 022d878a73 is described below commit 022d878a73dec796bb72743804a6dded7c239bd3 Author: Niels Pardon <p...@zurich.ibm.com> AuthorDate: Fri Feb 23 11:04:07 2024 +0100 [CALCITE-6261] AssertionError with field pruning & duplicate agg calls Signed-off-by: Niels Pardon <p...@zurich.ibm.com> --- .../java/org/apache/calcite/tools/RelBuilder.java | 55 +++++++++++-------- .../org/apache/calcite/test/RelBuilderTest.java | 61 ++++++++++++++++++++++ 2 files changed, 95 insertions(+), 21 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java index ad58474bae..757e81265f 100644 --- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java +++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java @@ -2442,9 +2442,22 @@ public class RelBuilder { assert groupSet.contains(set); } - PairList<ImmutableSet<String>, RelDataTypeField> inFields = frame.fields; - final ImmutableBitSet groupSet2; - final ImmutableList<ImmutableBitSet> groupSets2; + return pruneAggregateInputFieldsAndDeduplicateAggCalls(r, groupSet, groupSets, aggregateCalls, + frame.fields, registrar.extraNodes); + } + + /** + * Prunes unused fields on the input of the aggregate and removes duplicate aggregation calls. + */ + private RelBuilder pruneAggregateInputFieldsAndDeduplicateAggCalls( + RelNode r, + final ImmutableBitSet groupSet, + final ImmutableList<ImmutableBitSet> groupSets, + final List<AggregateCall> aggregateCalls, + PairList<ImmutableSet<String>, RelDataTypeField> inFields, + final List<RexNode> extraNodes) { + final ImmutableBitSet groupSetAfterPruning; + final ImmutableList<ImmutableBitSet> groupSetsAfterPruning; if (config.pruneInputOfAggregate() && r instanceof Project) { final Set<Integer> fieldsUsed = @@ -2453,22 +2466,22 @@ public class RelBuilder { // pretend that one field is used. if (fieldsUsed.isEmpty()) { r = ((Project) r).getInput(); - groupSet2 = groupSet; - groupSets2 = groupSets; + groupSetAfterPruning = groupSet; + groupSetsAfterPruning = groupSets; } else if (fieldsUsed.size() < r.getRowType().getFieldCount()) { // Some fields are computed but not used. Prune them. - final Map<Integer, Integer> map = new HashMap<>(); + final Map<Integer, Integer> sourceFieldToTargetFieldMap = new HashMap<>(); for (int source : fieldsUsed) { - map.put(source, map.size()); + sourceFieldToTargetFieldMap.put(source, sourceFieldToTargetFieldMap.size()); } - groupSet2 = groupSet.permute(map); - groupSets2 = + groupSetAfterPruning = groupSet.permute(sourceFieldToTargetFieldMap); + groupSetsAfterPruning = ImmutableBitSet.ORDERING.immutableSortedCopy( - ImmutableBitSet.permute(groupSets, map)); + ImmutableBitSet.permute(groupSets, sourceFieldToTargetFieldMap)); final Mappings.TargetMapping targetMapping = - Mappings.target(map, r.getRowType().getFieldCount(), + Mappings.target(sourceFieldToTargetFieldMap, r.getRowType().getFieldCount(), fieldsUsed.size()); final List<AggregateCall> oldAggregateCalls = new ArrayList<>(aggregateCalls); @@ -2493,24 +2506,24 @@ public class RelBuilder { project.copy(cluster.traitSet(), project.getInput(), newProjects, builder.build()); } else { - groupSet2 = groupSet; - groupSets2 = groupSets; + groupSetAfterPruning = groupSet; + groupSetsAfterPruning = groupSets; } } else { - groupSet2 = groupSet; - groupSets2 = groupSets; + groupSetAfterPruning = groupSet; + groupSetsAfterPruning = groupSets; } if (!config.dedupAggregateCalls() || Util.isDistinct(aggregateCalls)) { - return aggregate_(groupSet2, groupSets2, r, aggregateCalls, - registrar.extraNodes, inFields); + return aggregate_(groupSetAfterPruning, groupSetsAfterPruning, r, aggregateCalls, + extraNodes, inFields); } // There are duplicate aggregate calls. Rebuild the list to eliminate // duplicates, then add a Project. final Set<AggregateCall> callSet = new HashSet<>(); final PairList<Integer, @Nullable String> projects = PairList.of(); - Util.range(groupSet.cardinality()) + Util.range(groupSetAfterPruning.cardinality()) .forEach(i -> projects.add(i, null)); final List<AggregateCall> distinctAggregateCalls = new ArrayList<>(); for (AggregateCall aggregateCall : aggregateCalls) { @@ -2522,10 +2535,10 @@ public class RelBuilder { i = distinctAggregateCalls.indexOf(aggregateCall); assert i >= 0; } - projects.add(groupSet.cardinality() + i, aggregateCall.name); + projects.add(groupSetAfterPruning.cardinality() + i, aggregateCall.name); } - aggregate_(groupSet, groupSets, r, distinctAggregateCalls, - registrar.extraNodes, inFields); + aggregate_(groupSetAfterPruning, groupSetsAfterPruning, r, distinctAggregateCalls, + extraNodes, inFields); return project(projects.transform((i, name) -> aliasMaybe(field(i), name))); } diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java index 953196deb9..7f8f42c51b 100644 --- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java @@ -1490,6 +1490,67 @@ public class RelBuilderTest { assertThat(root, hasTree(expected)); } + /** + * Test reproducing issue CALCITE-6261. + */ + @Test void testAggregateDuplicateAggCallsWithForceProjectAndFieldPruning() { + final Function<RelBuilder, RelNode> f1 = builder -> + // single table scan with force project of all columns + builder.scan("EMP") + .project( + ImmutableList.of( + builder.field("EMPNO"), + builder.field("ENAME"), + builder.field("JOB"), + builder.field("MGR"), + builder.field("HIREDATE"), + builder.field("SAL"), + builder.field("COMM"), + builder.field("DEPTNO")), + ImmutableList.of(), + true) + .aggregate( + builder.groupKey(builder.field("MGR")), + // duplicate avg() agg calls + builder.avg(false, "SALARY_AVG", builder.field("SAL")), + builder.sum(false, "SALARY_SUM", builder.field("SAL")), + builder.avg(false, "SALARY_MEAN", builder.field("SAL"))) + .build(); + final String expected = "" + + "LogicalProject(MGR=[$0], SALARY_AVG=[$1], SALARY_SUM=[$2], SALARY_MEAN=[$1])\n" + + " LogicalAggregate(group=[{0}], SALARY_AVG=[AVG($1)], SALARY_SUM=[SUM($1)])\n" + + " LogicalProject(MGR=[$3], SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + assertThat(f1.apply(createBuilder()), hasTree(expected)); + } + + /** + * Test recreating the reproducer from issue CALCITE-5888 but with the existing scott tables. + */ + @Test void testAggregateDuplicateAggCallsAndFieldPruningWithJoinAndLiteralGroupKey() { + final Function<RelBuilder, RelNode> f1 = builder -> + // first inner join two tables + builder.scan("EMP").scan("DEPT") + .join(JoinRelType.INNER, "DEPTNO") + .aggregate( + // null group key + builder.groupKey(builder.cast(builder.literal(null), SqlTypeName.INTEGER)), + // duplicated min/max agg calls + builder.min(builder.field("SAL")), + builder.max(builder.field("SAL")), + builder.min(builder.field("SAL")), + builder.max(builder.field("SAL"))) + .build(); + final String expected = "" + + "LogicalProject($f11=[$0], $f1=[$1], $f2=[$2], $f10=[$1], $f20=[$2])\n" + + " LogicalAggregate(group=[{1}], agg#0=[MIN($0)], agg#1=[MAX($0)])\n" + + " LogicalProject(SAL=[$5], $f11=[null:INTEGER])\n" + + " LogicalJoin(condition=[=($7, $8)], joinType=[inner])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n"; + assertThat(f1.apply(createBuilder()), hasTree(expected)); + } + @Test void testAggregateFilter() { // Equivalent SQL: // SELECT deptno, COUNT(*) FILTER (WHERE empno > 100) AS c