This is an automated email from the ASF dual-hosted git repository.

hyuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/master by this push:
     new 8a1535f  [CALCITE-4652] AggregateExpandDistinctAggregatesRule must 
cast top aggregates to original type (Taras Ledkov)
8a1535f is described below

commit 8a1535f94aad1e0ce77797eb84d75b4a5b1692c1
Author: tledkov <tled...@gridgain.com>
AuthorDate: Fri Jun 4 17:54:17 2021 +0300

    [CALCITE-4652] AggregateExpandDistinctAggregatesRule must cast top 
aggregates to original type (Taras Ledkov)
    
    Close #2439
---
 .../AggregateExpandDistinctAggregatesRule.java     | 12 ++++-
 .../org/apache/calcite/test/RelOptRulesTest.java   | 44 ++++++++++++++++
 .../org/apache/calcite/test/SqlToRelTestBase.java  | 58 +++++++++++++---------
 .../org/apache/calcite/test/RelOptRulesTest.xml    | 21 ++++++++
 4 files changed, 111 insertions(+), 24 deletions(-)

diff --git 
a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
 
b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
index 6ef9dae..cec3e58 100644
--- 
a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
+++ 
b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
@@ -26,6 +26,7 @@ import org.apache.calcite.rel.core.AggregateCall;
 import org.apache.calcite.rel.core.JoinRelType;
 import org.apache.calcite.rel.core.RelFactories;
 import org.apache.calcite.rel.logical.LogicalAggregate;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
 import org.apache.calcite.rel.type.RelDataTypeField;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexInputRef;
@@ -366,12 +367,15 @@ public final class AggregateExpandDistinctAggregatesRule
         final int arg = bottomGroups.size() + nonDistinctAggCallProcessedSoFar;
         final List<Integer> newArgs = ImmutableList.of(arg);
         if (aggCall.getAggregation().getKind() == SqlKind.COUNT) {
+          RelDataTypeFactory typeFactory = 
aggregate.getCluster().getTypeFactory();
+
           newCall =
               AggregateCall.create(new SqlSumEmptyIsZeroAggFunction(), false,
                   aggCall.isApproximate(), aggCall.ignoreNulls(),
                   newArgs, -1, aggCall.distinctKeys, aggCall.collation,
                   originalGroupSet.cardinality(), relBuilder.peek(),
-                  aggCall.getType(), aggCall.getName());
+                  typeFactory.getTypeSystem().deriveSumType(typeFactory, 
aggCall.getType()),
+                  aggCall.getName());
         } else {
           newCall =
               AggregateCall.create(aggCall.getAggregation(), false,
@@ -400,6 +404,12 @@ public final class AggregateExpandDistinctAggregatesRule
     relBuilder.push(
         aggregate.copy(aggregate.getTraitSet(), relBuilder.build(),
             ImmutableBitSet.of(topGroupSet), null, topAggregateCalls));
+
+    // Add projection node for case: SUM of COUNT(*):
+    // Type of the SUM may be larger than type of COUNT.
+    // CAST to original type must be added.
+    relBuilder.convert(aggregate.getRowType(), true);
+
     return relBuilder;
   }
 
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java 
b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index d2b53c1..809289b 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -89,6 +89,7 @@ import org.apache.calcite.rel.rules.UnionMergeRule;
 import org.apache.calcite.rel.rules.ValuesReduceRule;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rel.type.RelDataTypeSystemImpl;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexCall;
 import org.apache.calcite.rex.RexInputRef;
@@ -107,6 +108,7 @@ import org.apache.calcite.sql.fun.SqlLibrary;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.sql.type.OperandTypes;
 import org.apache.calcite.sql.type.ReturnTypes;
+import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.sql.validate.SqlConformanceEnum;
 import org.apache.calcite.sql.validate.SqlMonotonicity;
@@ -136,6 +138,7 @@ import java.util.List;
 import java.util.Locale;
 import java.util.function.Function;
 import java.util.function.Predicate;
+import java.util.function.Supplier;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNotNull;
@@ -6769,4 +6772,45 @@ class RelOptRulesTest extends RelOptTestBase {
       relFn(relFn).with(hepPlanner).checkUnchanged();
     }
   }
+
+  /**
+   * Test case for <a 
href="https://issues.apache.org/jira/browse/CALCITE-4652">[CALCITE-4652]
+   * AggregateExpandDistinctAggregatesRule must cast top aggregates to 
original type</a>.
+   * <p>
+   * Checks AggregateExpandDistinctAggregatesRule when return type of the SUM 
aggregate
+   * is changed (expanded) by defining a custom type factory.
+   */
+  @Test void testDistinctCountWithExpandSumType() {
+    // Define new type system to expand SUM return type.
+    RelDataTypeSystemImpl typeSystem = new RelDataTypeSystemImpl() {
+      @Override public RelDataType deriveSumType(RelDataTypeFactory 
typeFactory,
+          RelDataType argumentType) {
+        switch (argumentType.getSqlTypeName()) {
+        case INTEGER:
+        case BIGINT:
+          return typeFactory.createSqlType(SqlTypeName.DECIMAL);
+
+        default:
+          return super.deriveSumType(typeFactory, argumentType);
+        }
+      }
+    };
+
+    Supplier<RelDataTypeFactory> typeFactorySupplier = () -> new 
SqlTypeFactoryImpl(typeSystem);
+
+    // Expected plan:
+    // LogicalProject(EXPR$0=[CAST($0):BIGINT NOT NULL], EXPR$1=[$1])
+    //   LogicalAggregate(group=[{}], EXPR$0=[$SUM0($1)], EXPR$1=[COUNT($0)])
+    //     LogicalAggregate(group=[{0}], EXPR$0=[COUNT()])
+    //       LogicalProject(COMM=[$6])
+    //         LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    //
+    // The top 'LogicalProject' must be added in case SUM type is expanded
+    // because type of original expression 'COUNT(DISTINCT comm)' is BIGINT
+    // and type of SUM (of BIGINT) is DECIMAL.
+    sql("SELECT count(comm), COUNT(DISTINCT comm) FROM emp")
+        .withTester(t -> t.withTypeFactorySupplier(typeFactorySupplier))
+        .withRule(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN)
+        .check();
+  }
 }
diff --git a/core/src/test/java/org/apache/calcite/test/SqlToRelTestBase.java 
b/core/src/test/java/org/apache/calcite/test/SqlToRelTestBase.java
index 93921f1..85d581b 100644
--- a/core/src/test/java/org/apache/calcite/test/SqlToRelTestBase.java
+++ b/core/src/test/java/org/apache/calcite/test/SqlToRelTestBase.java
@@ -74,6 +74,7 @@ import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.calcite.util.TestUtil;
 
+import com.google.common.base.Suppliers;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Iterables;
 
@@ -81,6 +82,7 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
 import java.util.function.Function;
+import java.util.function.Supplier;
 import java.util.function.UnaryOperator;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -100,6 +102,8 @@ public abstract class SqlToRelTestBase {
   //~ Static fields/initializers ---------------------------------------------
 
   protected static final String NL = System.getProperty("line.separator");
+  protected static final Supplier<RelDataTypeFactory> 
DEFAULT_TYPE_FACTORY_SUPPLIER =
+      Suppliers.memoize(() -> new 
SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT));
 
   //~ Instance fields --------------------------------------------------------
 
@@ -111,7 +115,7 @@ public abstract class SqlToRelTestBase {
     final TesterImpl tester =
         new TesterImpl(getDiffRepos(), false, false, false, true, null, null,
             MockRelOptPlanner::new, UnaryOperator.identity(),
-            SqlConformanceEnum.DEFAULT, UnaryOperator.identity());
+            SqlConformanceEnum.DEFAULT, UnaryOperator.identity(), 
DEFAULT_TYPE_FACTORY_SUPPLIER);
     return tester.withConfig(c ->
         c.withTrimUnusedFields(true)
             .withExpand(true)
@@ -287,6 +291,9 @@ public abstract class SqlToRelTestBase {
     /** Returns a tester that uses a given context. */
     Tester withContext(UnaryOperator<Context> transform);
 
+    /** Returns a tester that uses a type factory. */
+    Tester withTypeFactorySupplier(Supplier<RelDataTypeFactory> 
typeFactorySupplier);
+
     /** Trims a RelNode. */
     RelNode trimRelNode(RelNode relNode);
 
@@ -564,7 +571,7 @@ public abstract class SqlToRelTestBase {
     private final SqlConformance conformance;
     private final SqlTestFactory.MockCatalogReaderFactory catalogReaderFactory;
     private final Function<RelOptCluster, RelOptCluster> clusterFactory;
-    private RelDataTypeFactory typeFactory;
+    private final Supplier<RelDataTypeFactory> typeFactorySupplier;
     private final UnaryOperator<SqlToRelConverter.Config> configTransform;
     private final UnaryOperator<Context> contextTransform;
 
@@ -572,7 +579,7 @@ public abstract class SqlToRelTestBase {
     protected TesterImpl(DiffRepository diffRepos) {
       this(diffRepos, true, true, false, true, null, null,
           MockRelOptPlanner::new, UnaryOperator.identity(),
-          SqlConformanceEnum.DEFAULT, c -> Contexts.empty());
+          SqlConformanceEnum.DEFAULT, c -> Contexts.empty(), 
DEFAULT_TYPE_FACTORY_SUPPLIER);
     }
 
     /**
@@ -591,7 +598,8 @@ public abstract class SqlToRelTestBase {
         Function<RelOptCluster, RelOptCluster> clusterFactory,
         Function<Context, RelOptPlanner> plannerFactory,
         UnaryOperator<SqlToRelConverter.Config> configTransform,
-        SqlConformance conformance, UnaryOperator<Context> contextTransform) {
+        SqlConformance conformance, UnaryOperator<Context> contextTransform,
+        Supplier<RelDataTypeFactory> typeFactorySupplier) {
       this.diffRepos = diffRepos;
       this.enableDecorrelate = enableDecorrelate;
       this.enableTrim = enableTrim;
@@ -603,6 +611,7 @@ public abstract class SqlToRelTestBase {
       this.plannerFactory = Objects.requireNonNull(plannerFactory, 
"plannerFactory");
       this.conformance = Objects.requireNonNull(conformance, "conformance");
       this.contextTransform = Objects.requireNonNull(contextTransform, 
"contextTransform");
+      this.typeFactorySupplier = Objects.requireNonNull(typeFactorySupplier, 
"typeFactorySupplier");
     }
 
     public RelRoot convertSqlToRel(String sql) {
@@ -667,7 +676,7 @@ public abstract class SqlToRelTestBase {
       return createSqlToRelConverter(
           validator,
           catalogReader,
-          typeFactory,
+          getTypeFactory(),
           config);
     }
 
@@ -689,14 +698,7 @@ public abstract class SqlToRelTestBase {
     }
 
     protected final RelDataTypeFactory getTypeFactory() {
-      if (typeFactory == null) {
-        typeFactory = createTypeFactory();
-      }
-      return typeFactory;
-    }
-
-    protected RelDataTypeFactory createTypeFactory() {
-      return new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
+      return typeFactorySupplier.get();
     }
 
     protected final RelOptPlanner getPlanner() {
@@ -899,7 +901,7 @@ public abstract class SqlToRelTestBase {
           : new TesterImpl(diffRepos, enableDecorrelate, enableTrim,
               enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory,
               clusterFactory, plannerFactory, configTransform, conformance,
-              contextTransform);
+              contextTransform, typeFactorySupplier);
     }
 
     public TesterImpl withLateDecorrelation(boolean enableLateDecorrelate) {
@@ -908,7 +910,7 @@ public abstract class SqlToRelTestBase {
           : new TesterImpl(diffRepos, enableDecorrelate, enableTrim,
               enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory,
               clusterFactory, plannerFactory, configTransform, conformance,
-              contextTransform);
+              contextTransform, typeFactorySupplier);
     }
 
     public Tester withConfig(UnaryOperator<SqlToRelConverter.Config> 
transform) {
@@ -917,7 +919,7 @@ public abstract class SqlToRelTestBase {
       return new TesterImpl(diffRepos, enableDecorrelate, enableTrim,
           enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory,
           clusterFactory, plannerFactory, configTransform, conformance,
-          contextTransform);
+          contextTransform, typeFactorySupplier);
     }
 
     public TesterImpl withTrim(boolean enableTrim) {
@@ -926,21 +928,21 @@ public abstract class SqlToRelTestBase {
           : new TesterImpl(diffRepos, enableDecorrelate, enableTrim,
               enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory,
               clusterFactory, plannerFactory, configTransform, conformance,
-              contextTransform);
+              contextTransform, typeFactorySupplier);
     }
 
     public TesterImpl withConformance(SqlConformance conformance) {
       return new TesterImpl(diffRepos, enableDecorrelate, enableTrim,
           enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory,
           clusterFactory, plannerFactory, configTransform, conformance,
-          contextTransform);
+          contextTransform, typeFactorySupplier);
     }
 
     public Tester enableTypeCoercion(boolean enableTypeCoercion) {
       return new TesterImpl(diffRepos, enableDecorrelate, enableTrim,
           enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory,
           clusterFactory, plannerFactory, configTransform, conformance,
-          contextTransform);
+          contextTransform, typeFactorySupplier);
     }
 
     public Tester withCatalogReaderFactory(
@@ -948,7 +950,7 @@ public abstract class SqlToRelTestBase {
       return new TesterImpl(diffRepos, enableDecorrelate, enableTrim,
           enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory,
           clusterFactory, plannerFactory, configTransform, conformance,
-          contextTransform);
+          contextTransform, typeFactorySupplier);
     }
 
     public Tester withClusterFactory(
@@ -956,7 +958,7 @@ public abstract class SqlToRelTestBase {
       return new TesterImpl(diffRepos, enableDecorrelate, enableTrim,
           enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory,
           clusterFactory, plannerFactory, configTransform, conformance,
-          contextTransform);
+          contextTransform, typeFactorySupplier);
     }
 
     public Tester withPlannerFactory(
@@ -966,14 +968,24 @@ public abstract class SqlToRelTestBase {
           : new TesterImpl(diffRepos, enableDecorrelate, enableTrim,
               enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory,
               clusterFactory, plannerFactory, configTransform, conformance,
-              contextTransform);
+              contextTransform, typeFactorySupplier);
+    }
+
+    public Tester withTypeFactorySupplier(
+        Supplier<RelDataTypeFactory> typeFactorySupplier) {
+      return this.typeFactorySupplier == typeFactorySupplier
+          ? this
+          : new TesterImpl(diffRepos, enableDecorrelate, enableTrim,
+              enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory,
+              clusterFactory, plannerFactory, configTransform, conformance,
+              contextTransform, typeFactorySupplier);
     }
 
     public TesterImpl withContext(UnaryOperator<Context> context) {
       return new TesterImpl(diffRepos, enableDecorrelate, enableTrim,
           enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory,
           clusterFactory, plannerFactory, configTransform, conformance,
-          context);
+          context, typeFactorySupplier);
     }
 
     public boolean isLateDecorrelate() {
diff --git 
a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml 
b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index f811fba..3ed9020 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -2369,6 +2369,27 @@ LogicalProject(DEPTNO=[$0], EXPR$1=[$3], EXPR$2=[$5], 
EXPR$3=[$7], EXPR$4=[$1])
 ]]>
     </Resource>
   </TestCase>
+  <TestCase name="testDistinctCountWithExpandSumType">
+    <Resource name="sql">
+      <![CDATA[SELECT count(comm), COUNT(DISTINCT comm) FROM emp]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[COUNT()], EXPR$1=[COUNT(DISTINCT $0)])
+  LogicalProject(COMM=[$6])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalProject(EXPR$0=[CAST($0):BIGINT NOT NULL], EXPR$1=[$1])
+  LogicalAggregate(group=[{}], EXPR$0=[$SUM0($1)], EXPR$1=[COUNT($0)])
+    LogicalAggregate(group=[{0}], EXPR$0=[COUNT()])
+      LogicalProject(COMM=[$6])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
   <TestCase name="testDistinctCountWithoutGroupBy">
     <Resource name="sql">
       <![CDATA[select max(deptno), count(distinct ename)

Reply via email to