This is an automated email from the ASF dual-hosted git repository. hyuan pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/master by this push: new 8a1535f [CALCITE-4652] AggregateExpandDistinctAggregatesRule must cast top aggregates to original type (Taras Ledkov) 8a1535f is described below commit 8a1535f94aad1e0ce77797eb84d75b4a5b1692c1 Author: tledkov <tled...@gridgain.com> AuthorDate: Fri Jun 4 17:54:17 2021 +0300 [CALCITE-4652] AggregateExpandDistinctAggregatesRule must cast top aggregates to original type (Taras Ledkov) Close #2439 --- .../AggregateExpandDistinctAggregatesRule.java | 12 ++++- .../org/apache/calcite/test/RelOptRulesTest.java | 44 ++++++++++++++++ .../org/apache/calcite/test/SqlToRelTestBase.java | 58 +++++++++++++--------- .../org/apache/calcite/test/RelOptRulesTest.xml | 21 ++++++++ 4 files changed, 111 insertions(+), 24 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java index 6ef9dae..cec3e58 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java @@ -26,6 +26,7 @@ import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.core.RelFactories; import org.apache.calcite.rel.logical.LogicalAggregate; +import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexInputRef; @@ -366,12 +367,15 @@ public final class AggregateExpandDistinctAggregatesRule final int arg = bottomGroups.size() + nonDistinctAggCallProcessedSoFar; final List<Integer> newArgs = ImmutableList.of(arg); if (aggCall.getAggregation().getKind() == SqlKind.COUNT) { + RelDataTypeFactory typeFactory = aggregate.getCluster().getTypeFactory(); + newCall = AggregateCall.create(new SqlSumEmptyIsZeroAggFunction(), false, aggCall.isApproximate(), aggCall.ignoreNulls(), newArgs, -1, aggCall.distinctKeys, aggCall.collation, originalGroupSet.cardinality(), relBuilder.peek(), - aggCall.getType(), aggCall.getName()); + typeFactory.getTypeSystem().deriveSumType(typeFactory, aggCall.getType()), + aggCall.getName()); } else { newCall = AggregateCall.create(aggCall.getAggregation(), false, @@ -400,6 +404,12 @@ public final class AggregateExpandDistinctAggregatesRule relBuilder.push( aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), ImmutableBitSet.of(topGroupSet), null, topAggregateCalls)); + + // Add projection node for case: SUM of COUNT(*): + // Type of the SUM may be larger than type of COUNT. + // CAST to original type must be added. + relBuilder.convert(aggregate.getRowType(), true); + return relBuilder; } diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java index d2b53c1..809289b 100644 --- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java @@ -89,6 +89,7 @@ import org.apache.calcite.rel.rules.UnionMergeRule; import org.apache.calcite.rel.rules.ValuesReduceRule; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeSystemImpl; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexInputRef; @@ -107,6 +108,7 @@ import org.apache.calcite.sql.fun.SqlLibrary; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlTypeFactoryImpl; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.validate.SqlConformanceEnum; import org.apache.calcite.sql.validate.SqlMonotonicity; @@ -136,6 +138,7 @@ import java.util.List; import java.util.Locale; import java.util.function.Function; import java.util.function.Predicate; +import java.util.function.Supplier; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -6769,4 +6772,45 @@ class RelOptRulesTest extends RelOptTestBase { relFn(relFn).with(hepPlanner).checkUnchanged(); } } + + /** + * Test case for <a href="https://issues.apache.org/jira/browse/CALCITE-4652">[CALCITE-4652] + * AggregateExpandDistinctAggregatesRule must cast top aggregates to original type</a>. + * <p> + * Checks AggregateExpandDistinctAggregatesRule when return type of the SUM aggregate + * is changed (expanded) by define custom type factory. + */ + @Test void testDistinctCountWithExpandSumType() { + // Define new type system to expand SUM return type. + RelDataTypeSystemImpl typeSystem = new RelDataTypeSystemImpl() { + @Override public RelDataType deriveSumType(RelDataTypeFactory typeFactory, + RelDataType argumentType) { + switch (argumentType.getSqlTypeName()) { + case INTEGER: + case BIGINT: + return typeFactory.createSqlType(SqlTypeName.DECIMAL); + + default: + return super.deriveSumType(typeFactory, argumentType); + } + } + }; + + Supplier<RelDataTypeFactory> typeFactorySupplier = () -> new SqlTypeFactoryImpl(typeSystem); + + // Expected plan: + // LogicalProject(EXPR$0=[CAST($0):BIGINT NOT NULL], EXPR$1=[$1]) + // LogicalAggregate(group=[{}], EXPR$0=[$SUM0($1)], EXPR$1=[COUNT($0)]) + // LogicalAggregate(group=[{0}], EXPR$0=[COUNT()]) + // LogicalProject(COMM=[$6]) + // LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + // + // The top 'LogicalProject' must be added in case SUM type is expanded + // because type of original expression 'COUNT(DISTINCT comm)' is BIGINT + // and type of SUM (of BIGINT) is DECIMAL. + sql("SELECT count(comm), COUNT(DISTINCT comm) FROM emp") + .withTester(t -> t.withTypeFactorySupplier(typeFactorySupplier)) + .withRule(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN) + .check(); + } } diff --git a/core/src/test/java/org/apache/calcite/test/SqlToRelTestBase.java b/core/src/test/java/org/apache/calcite/test/SqlToRelTestBase.java index 93921f1..85d581b 100644 --- a/core/src/test/java/org/apache/calcite/test/SqlToRelTestBase.java +++ b/core/src/test/java/org/apache/calcite/test/SqlToRelTestBase.java @@ -74,6 +74,7 @@ import org.apache.calcite.tools.RelBuilder; import org.apache.calcite.util.ImmutableBitSet; import org.apache.calcite.util.TestUtil; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; @@ -81,6 +82,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.function.Function; +import java.util.function.Supplier; import java.util.function.UnaryOperator; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -100,6 +102,8 @@ public abstract class SqlToRelTestBase { //~ Static fields/initializers --------------------------------------------- protected static final String NL = System.getProperty("line.separator"); + protected static final Supplier<RelDataTypeFactory> DEFAULT_TYPE_FACTORY_SUPPLIER = + Suppliers.memoize(() -> new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT)); //~ Instance fields -------------------------------------------------------- @@ -111,7 +115,7 @@ public abstract class SqlToRelTestBase { final TesterImpl tester = new TesterImpl(getDiffRepos(), false, false, false, true, null, null, MockRelOptPlanner::new, UnaryOperator.identity(), - SqlConformanceEnum.DEFAULT, UnaryOperator.identity()); + SqlConformanceEnum.DEFAULT, UnaryOperator.identity(), DEFAULT_TYPE_FACTORY_SUPPLIER); return tester.withConfig(c -> c.withTrimUnusedFields(true) .withExpand(true) @@ -287,6 +291,9 @@ public abstract class SqlToRelTestBase { /** Returns a tester that uses a given context. */ Tester withContext(UnaryOperator<Context> transform); + /** Returns a tester that uses a type factory. */ + Tester withTypeFactorySupplier(Supplier<RelDataTypeFactory> typeFactorySupplier); + /** Trims a RelNode. */ RelNode trimRelNode(RelNode relNode); @@ -564,7 +571,7 @@ public abstract class SqlToRelTestBase { private final SqlConformance conformance; private final SqlTestFactory.MockCatalogReaderFactory catalogReaderFactory; private final Function<RelOptCluster, RelOptCluster> clusterFactory; - private RelDataTypeFactory typeFactory; + private final Supplier<RelDataTypeFactory> typeFactorySupplier; private final UnaryOperator<SqlToRelConverter.Config> configTransform; private final UnaryOperator<Context> contextTransform; @@ -572,7 +579,7 @@ public abstract class SqlToRelTestBase { protected TesterImpl(DiffRepository diffRepos) { this(diffRepos, true, true, false, true, null, null, MockRelOptPlanner::new, UnaryOperator.identity(), - SqlConformanceEnum.DEFAULT, c -> Contexts.empty()); + SqlConformanceEnum.DEFAULT, c -> Contexts.empty(), DEFAULT_TYPE_FACTORY_SUPPLIER); } /** @@ -591,7 +598,8 @@ public abstract class SqlToRelTestBase { Function<RelOptCluster, RelOptCluster> clusterFactory, Function<Context, RelOptPlanner> plannerFactory, UnaryOperator<SqlToRelConverter.Config> configTransform, - SqlConformance conformance, UnaryOperator<Context> contextTransform) { + SqlConformance conformance, UnaryOperator<Context> contextTransform, + Supplier<RelDataTypeFactory> typeFactorySupplier) { this.diffRepos = diffRepos; this.enableDecorrelate = enableDecorrelate; this.enableTrim = enableTrim; @@ -603,6 +611,7 @@ public abstract class SqlToRelTestBase { this.plannerFactory = Objects.requireNonNull(plannerFactory, "plannerFactory"); this.conformance = Objects.requireNonNull(conformance, "conformance"); this.contextTransform = Objects.requireNonNull(contextTransform, "contextTransform"); + this.typeFactorySupplier = Objects.requireNonNull(typeFactorySupplier, "typeFactorySupplier"); } public RelRoot convertSqlToRel(String sql) { @@ -667,7 +676,7 @@ public abstract class SqlToRelTestBase { return createSqlToRelConverter( validator, catalogReader, - typeFactory, + getTypeFactory(), config); } @@ -689,14 +698,7 @@ public abstract class SqlToRelTestBase { } protected final RelDataTypeFactory getTypeFactory() { - if (typeFactory == null) { - typeFactory = createTypeFactory(); - } - return typeFactory; - } - - protected RelDataTypeFactory createTypeFactory() { - return new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT); + return typeFactorySupplier.get(); } protected final RelOptPlanner getPlanner() { @@ -899,7 +901,7 @@ public abstract class SqlToRelTestBase { : new TesterImpl(diffRepos, enableDecorrelate, enableTrim, enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory, clusterFactory, plannerFactory, configTransform, conformance, - contextTransform); + contextTransform, typeFactorySupplier); } public TesterImpl withLateDecorrelation(boolean enableLateDecorrelate) { @@ -908,7 +910,7 @@ public abstract class SqlToRelTestBase { : new TesterImpl(diffRepos, enableDecorrelate, enableTrim, enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory, clusterFactory, plannerFactory, configTransform, conformance, - contextTransform); + contextTransform, typeFactorySupplier); } public Tester withConfig(UnaryOperator<SqlToRelConverter.Config> transform) { @@ -917,7 +919,7 @@ public abstract class SqlToRelTestBase { return new TesterImpl(diffRepos, enableDecorrelate, enableTrim, enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory, clusterFactory, plannerFactory, configTransform, conformance, - contextTransform); + contextTransform, typeFactorySupplier); } public TesterImpl withTrim(boolean enableTrim) { @@ -926,21 +928,21 @@ public abstract class SqlToRelTestBase { : new TesterImpl(diffRepos, enableDecorrelate, enableTrim, enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory, clusterFactory, plannerFactory, configTransform, conformance, - contextTransform); + contextTransform, typeFactorySupplier); } public TesterImpl withConformance(SqlConformance conformance) { return new TesterImpl(diffRepos, enableDecorrelate, enableTrim, enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory, clusterFactory, plannerFactory, configTransform, conformance, - contextTransform); + contextTransform, typeFactorySupplier); } public Tester enableTypeCoercion(boolean enableTypeCoercion) { return new TesterImpl(diffRepos, enableDecorrelate, enableTrim, enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory, clusterFactory, plannerFactory, configTransform, conformance, - contextTransform); + contextTransform, typeFactorySupplier); } public Tester withCatalogReaderFactory( @@ -948,7 +950,7 @@ public abstract class SqlToRelTestBase { return new TesterImpl(diffRepos, enableDecorrelate, enableTrim, enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory, clusterFactory, plannerFactory, configTransform, conformance, - contextTransform); + contextTransform, typeFactorySupplier); } public Tester withClusterFactory( @@ -956,7 +958,7 @@ public abstract class SqlToRelTestBase { return new TesterImpl(diffRepos, enableDecorrelate, enableTrim, enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory, clusterFactory, plannerFactory, configTransform, conformance, - contextTransform); + contextTransform, typeFactorySupplier); } public Tester withPlannerFactory( @@ -966,14 +968,24 @@ public abstract class SqlToRelTestBase { : new TesterImpl(diffRepos, enableDecorrelate, enableTrim, enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory, clusterFactory, plannerFactory, configTransform, conformance, - contextTransform); + contextTransform, typeFactorySupplier); + } + + public Tester withTypeFactorySupplier( + Supplier<RelDataTypeFactory> typeFactorySupplier) { + return this.typeFactorySupplier == typeFactorySupplier + ? this + : new TesterImpl(diffRepos, enableDecorrelate, enableTrim, + enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory, + clusterFactory, plannerFactory, configTransform, conformance, + contextTransform, typeFactorySupplier); } public TesterImpl withContext(UnaryOperator<Context> context) { return new TesterImpl(diffRepos, enableDecorrelate, enableTrim, enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory, clusterFactory, plannerFactory, configTransform, conformance, - context); + context, typeFactorySupplier); } public boolean isLateDecorrelate() { diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index f811fba..3ed9020 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -2369,6 +2369,27 @@ LogicalProject(DEPTNO=[$0], EXPR$1=[$3], EXPR$2=[$5], EXPR$3=[$7], EXPR$4=[$1]) ]]> </Resource> </TestCase> + <TestCase name="testDistinctCountWithExpandSumType"> + <Resource name="sql"> + <![CDATA[SELECT count(comm), COUNT(DISTINCT comm) FROM emp]]> + </Resource> + <Resource name="planBefore"> + <![CDATA[ +LogicalAggregate(group=[{}], EXPR$0=[COUNT()], EXPR$1=[COUNT(DISTINCT $0)]) + LogicalProject(COMM=[$6]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + </Resource> + <Resource name="planAfter"> + <![CDATA[ +LogicalProject(EXPR$0=[CAST($0):BIGINT NOT NULL], EXPR$1=[$1]) + LogicalAggregate(group=[{}], EXPR$0=[$SUM0($1)], EXPR$1=[COUNT($0)]) + LogicalAggregate(group=[{0}], EXPR$0=[COUNT()]) + LogicalProject(COMM=[$6]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + </Resource> + </TestCase> <TestCase name="testDistinctCountWithoutGroupBy"> <Resource name="sql"> <![CDATA[select max(deptno), count(distinct ename)