This is an automated email from the ASF dual-hosted git repository.

mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 022d878a73 [CALCITE-6261] AssertionError with field pruning & duplicate agg calls
022d878a73 is described below

commit 022d878a73dec796bb72743804a6dded7c239bd3
Author: Niels Pardon <p...@zurich.ibm.com>
AuthorDate: Fri Feb 23 11:04:07 2024 +0100

    [CALCITE-6261] AssertionError with field pruning & duplicate agg calls
    
    Signed-off-by: Niels Pardon <p...@zurich.ibm.com>
---
 .../java/org/apache/calcite/tools/RelBuilder.java  | 55 +++++++++++--------
 .../org/apache/calcite/test/RelBuilderTest.java    | 61 ++++++++++++++++++++++
 2 files changed, 95 insertions(+), 21 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
index ad58474bae..757e81265f 100644
--- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
+++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
@@ -2442,9 +2442,22 @@ public class RelBuilder {
       assert groupSet.contains(set);
     }
 
-    PairList<ImmutableSet<String>, RelDataTypeField> inFields = frame.fields;
-    final ImmutableBitSet groupSet2;
-    final ImmutableList<ImmutableBitSet> groupSets2;
+    return pruneAggregateInputFieldsAndDeduplicateAggCalls(r, groupSet, 
groupSets, aggregateCalls,
+        frame.fields, registrar.extraNodes);
+  }
+
+  /**
+   * Prunes unused fields on the input of the aggregate and removes duplicate aggregation calls.
+   */
+  private RelBuilder pruneAggregateInputFieldsAndDeduplicateAggCalls(
+      RelNode r,
+      final ImmutableBitSet groupSet,
+      final ImmutableList<ImmutableBitSet> groupSets,
+      final List<AggregateCall> aggregateCalls,
+      PairList<ImmutableSet<String>, RelDataTypeField> inFields,
+      final List<RexNode> extraNodes) {
+    final ImmutableBitSet groupSetAfterPruning;
+    final ImmutableList<ImmutableBitSet> groupSetsAfterPruning;
     if (config.pruneInputOfAggregate()
         && r instanceof Project) {
       final Set<Integer> fieldsUsed =
@@ -2453,22 +2466,22 @@ public class RelBuilder {
       // pretend that one field is used.
       if (fieldsUsed.isEmpty()) {
         r = ((Project) r).getInput();
-        groupSet2 = groupSet;
-        groupSets2 = groupSets;
+        groupSetAfterPruning = groupSet;
+        groupSetsAfterPruning = groupSets;
       } else if (fieldsUsed.size() < r.getRowType().getFieldCount()) {
         // Some fields are computed but not used. Prune them.
-        final Map<Integer, Integer> map = new HashMap<>();
+        final Map<Integer, Integer> sourceFieldToTargetFieldMap = new 
HashMap<>();
         for (int source : fieldsUsed) {
-          map.put(source, map.size());
+          sourceFieldToTargetFieldMap.put(source, 
sourceFieldToTargetFieldMap.size());
         }
 
-        groupSet2 = groupSet.permute(map);
-        groupSets2 =
+        groupSetAfterPruning = groupSet.permute(sourceFieldToTargetFieldMap);
+        groupSetsAfterPruning =
             ImmutableBitSet.ORDERING.immutableSortedCopy(
-                ImmutableBitSet.permute(groupSets, map));
+                ImmutableBitSet.permute(groupSets, 
sourceFieldToTargetFieldMap));
 
         final Mappings.TargetMapping targetMapping =
-            Mappings.target(map, r.getRowType().getFieldCount(),
+            Mappings.target(sourceFieldToTargetFieldMap, 
r.getRowType().getFieldCount(),
                 fieldsUsed.size());
         final List<AggregateCall> oldAggregateCalls =
             new ArrayList<>(aggregateCalls);
@@ -2493,24 +2506,24 @@ public class RelBuilder {
             project.copy(cluster.traitSet(), project.getInput(), newProjects,
                 builder.build());
       } else {
-        groupSet2 = groupSet;
-        groupSets2 = groupSets;
+        groupSetAfterPruning = groupSet;
+        groupSetsAfterPruning = groupSets;
       }
     } else {
-      groupSet2 = groupSet;
-      groupSets2 = groupSets;
+      groupSetAfterPruning = groupSet;
+      groupSetsAfterPruning = groupSets;
     }
 
     if (!config.dedupAggregateCalls() || Util.isDistinct(aggregateCalls)) {
-      return aggregate_(groupSet2, groupSets2, r, aggregateCalls,
-          registrar.extraNodes, inFields);
+      return aggregate_(groupSetAfterPruning, groupSetsAfterPruning, r, 
aggregateCalls,
+          extraNodes, inFields);
     }
 
     // There are duplicate aggregate calls. Rebuild the list to eliminate
     // duplicates, then add a Project.
     final Set<AggregateCall> callSet = new HashSet<>();
     final PairList<Integer, @Nullable String> projects = PairList.of();
-    Util.range(groupSet.cardinality())
+    Util.range(groupSetAfterPruning.cardinality())
         .forEach(i -> projects.add(i, null));
     final List<AggregateCall> distinctAggregateCalls = new ArrayList<>();
     for (AggregateCall aggregateCall : aggregateCalls) {
@@ -2522,10 +2535,10 @@ public class RelBuilder {
         i = distinctAggregateCalls.indexOf(aggregateCall);
         assert i >= 0;
       }
-      projects.add(groupSet.cardinality() + i, aggregateCall.name);
+      projects.add(groupSetAfterPruning.cardinality() + i, aggregateCall.name);
     }
-    aggregate_(groupSet, groupSets, r, distinctAggregateCalls,
-        registrar.extraNodes, inFields);
+    aggregate_(groupSetAfterPruning, groupSetsAfterPruning, r, 
distinctAggregateCalls,
+        extraNodes, inFields);
     return project(projects.transform((i, name) -> aliasMaybe(field(i), 
name)));
   }
 
diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
index 953196deb9..7f8f42c51b 100644
--- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
@@ -1490,6 +1490,67 @@ public class RelBuilderTest {
     assertThat(root, hasTree(expected));
   }
 
+  /**
+   * Test reproducing issue CALCITE-6261.
+   */
+  @Test void testAggregateDuplicateAggCallsWithForceProjectAndFieldPruning() {
+    final Function<RelBuilder, RelNode> f1 = builder ->
+        // single table scan with force project of all columns
+        builder.scan("EMP")
+            .project(
+                ImmutableList.of(
+                    builder.field("EMPNO"),
+                    builder.field("ENAME"),
+                    builder.field("JOB"),
+                    builder.field("MGR"),
+                    builder.field("HIREDATE"),
+                    builder.field("SAL"),
+                    builder.field("COMM"),
+                    builder.field("DEPTNO")),
+                ImmutableList.of(),
+                true)
+            .aggregate(
+                builder.groupKey(builder.field("MGR")),
+                // duplicate avg() agg calls
+                builder.avg(false, "SALARY_AVG", builder.field("SAL")),
+                builder.sum(false, "SALARY_SUM", builder.field("SAL")),
+                builder.avg(false, "SALARY_MEAN", builder.field("SAL")))
+            .build();
+    final String expected = ""
+        + "LogicalProject(MGR=[$0], SALARY_AVG=[$1], SALARY_SUM=[$2], 
SALARY_MEAN=[$1])\n"
+        + "  LogicalAggregate(group=[{0}], SALARY_AVG=[AVG($1)], 
SALARY_SUM=[SUM($1)])\n"
+        + "    LogicalProject(MGR=[$3], SAL=[$5])\n"
+        + "      LogicalTableScan(table=[[scott, EMP]])\n";
+    assertThat(f1.apply(createBuilder()), hasTree(expected));
+  }
+
+  /**
+   * Test recreating the reproducer from issue CALCITE-5888 but with the existing scott tables.
+   */
+  @Test void 
testAggregateDuplicateAggCallsAndFieldPruningWithJoinAndLiteralGroupKey() {
+    final Function<RelBuilder, RelNode> f1 = builder ->
+        // first inner join two tables
+        builder.scan("EMP").scan("DEPT")
+            .join(JoinRelType.INNER, "DEPTNO")
+            .aggregate(
+                // null group key
+                builder.groupKey(builder.cast(builder.literal(null), 
SqlTypeName.INTEGER)),
+                // duplicated min/max agg calls
+                builder.min(builder.field("SAL")),
+                builder.max(builder.field("SAL")),
+                builder.min(builder.field("SAL")),
+                builder.max(builder.field("SAL")))
+            .build();
+    final String expected = ""
+        + "LogicalProject($f11=[$0], $f1=[$1], $f2=[$2], $f10=[$1], 
$f20=[$2])\n"
+        + "  LogicalAggregate(group=[{1}], agg#0=[MIN($0)], agg#1=[MAX($0)])\n"
+        + "    LogicalProject(SAL=[$5], $f11=[null:INTEGER])\n"
+        + "      LogicalJoin(condition=[=($7, $8)], joinType=[inner])\n"
+        + "        LogicalTableScan(table=[[scott, EMP]])\n"
+        + "        LogicalTableScan(table=[[scott, DEPT]])\n";
+    assertThat(f1.apply(createBuilder()), hasTree(expected));
+  }
+
   @Test void testAggregateFilter() {
     // Equivalent SQL:
     //   SELECT deptno, COUNT(*) FILTER (WHERE empno > 100) AS c

Reply via email to