This is an automated email from the ASF dual-hosted git repository. korlov pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/ignite-3.git
The following commit(s) were added to refs/heads/main by this push: new 625aad225d IGNITE-20009: Sql. Rework 2-phase aggregates part 2. AVG as SUM / COUNT. (#3413) 625aad225d is described below commit 625aad225d2074cbf7e5419936532feb80d12cba Author: Max Zhuravkov <shh...@gmail.com> AuthorDate: Tue Mar 26 15:56:23 2024 +0200 IGNITE-20009: Sql. Rework 2-phase aggregates part 2. AVG as SUM / COUNT. (#3413) --- .../internal/sql/engine/ItAggregatesTest.java | 190 +++++++++- .../sql/engine/exec/LogicalRelImplementor.java | 5 +- .../sql/engine/exec/exp/IgniteSqlFunctions.java | 10 + .../internal/sql/engine/exec/exp/RexImpTable.java | 1 + .../sql/engine/rel/agg/MapReduceAggregates.java | 381 +++++++++++++++++++-- .../engine/rule/HashAggregateConverterRule.java | 24 +- .../engine/rule/SortAggregateConverterRule.java | 19 +- .../sql/engine/sql/fun/IgniteSqlOperatorTable.java | 27 ++ .../internal/sql/engine/util/IgniteMethod.java | 10 +- .../ignite/internal/sql/engine/util/PlanUtils.java | 12 +- .../engine/exec/exp/IgniteSqlFunctionsTest.java | 23 ++ .../exec/rel/HashAggregateExecutionTest.java | 8 +- .../rel/HashAggregateSingleGroupExecutionTest.java | 4 +- .../exec/rel/SortAggregateExecutionTest.java | 8 +- .../sql/engine/planner/AbstractPlannerTest.java | 4 + .../engine/planner/MapReduceAggregatesTest.java | 11 +- .../planner/MapReduceHashAggregatePlannerTest.java | 66 ++-- .../planner/MapReduceSortAggregatePlannerTest.java | 51 +-- .../internal/sql/engine/util/QueryChecker.java | 4 +- 19 files changed, 733 insertions(+), 125 deletions(-) diff --git a/modules/sql-engine/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItAggregatesTest.java b/modules/sql-engine/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItAggregatesTest.java index 745f3c8d75..f3e12c117a 100644 --- a/modules/sql-engine/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItAggregatesTest.java +++ b/modules/sql-engine/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItAggregatesTest.java @@ -21,10 +21,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import java.math.BigDecimal; +import java.math.MathContext; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Locale; -import java.util.stream.Collectors; +import java.util.Random; import java.util.stream.Stream; import org.apache.ignite.internal.sql.BaseSqlIntegrationTest; import org.apache.ignite.internal.sql.engine.hint.IgniteHint; @@ -49,9 +52,6 @@ public class ItAggregatesTest extends BaseSqlIntegrationTest { private static final List<String> MAP_REDUCE_RULES = List.of("MapReduceHashAggregateConverterRule", "MapReduceSortAggregateConverterRule"); - private static final List<String> COLO_RULES = Arrays.stream(DISABLED_RULES).filter(r -> !MAP_REDUCE_RULES.contains(r)) - .collect(Collectors.toList()); - private static final int ROWS = 103; /** @@ -88,6 +88,25 @@ public class ItAggregatesTest extends BaseSqlIntegrationTest { sql("CREATE TABLE test_str_int_real_dec " + "(id INTEGER PRIMARY KEY, str_col VARCHAR, int_col INTEGER, real_col REAL, dec_col DECIMAL)"); + + sql("CREATE TABLE IF NOT EXISTS numbers (" + + "id INTEGER PRIMARY KEY, " + + "tinyint_col TINYINT, " + + "smallint_col SMALLINT, " + + "int_col INTEGER, " + + "bigint_col BIGINT, " + + "float_col REAL, " + + "double_col DOUBLE, " + + "dec2_col DECIMAL(2), " + + "dec4_2_col DECIMAL(4,2), " + + "dec10_2_col DECIMAL(10,2) " + + ")"); + + sql("CREATE TABLE IF NOT EXISTS not_null_numbers (" + + "id INTEGER PRIMARY KEY, " + + "int_col INTEGER NOT NULL, " + + "dec4_2_col DECIMAL(4,2) NOT NULL" + + ")"); } @ParameterizedTest @@ -530,27 +549,166 @@ public class ItAggregatesTest extends BaseSqlIntegrationTest { @ParameterizedTest @MethodSource("provideRules") - public void testAvgOnEmptyGroup(String[] rules) { - sql("DELETE FROM test_str_int_real_dec"); + public void testAvg(String[] rules) { + sql("DELETE FROM numbers"); + sql("INSERT INTO numbers VALUES (1, 1, 1, 1, 1, 1, 1, 1, 1, 1), (2, 2, 2, 2, 2, 2, 2, 2, 2, 2)"); + + assertQuery("SELECT " + + "AVG(tinyint_col), AVG(smallint_col), AVG(int_col), AVG(bigint_col), " + + "AVG(float_col), AVG(double_col), AVG(dec2_col), AVG(dec4_2_col) " + + "FROM numbers") + .disableRules(rules) + .returns((byte) 1, (short) 1, 1, 1L, 1.5f, 1.5d, new BigDecimal("1.5"), new BigDecimal("1.50")) + .check(); + + sql("DELETE FROM numbers"); + sql("INSERT INTO numbers (id, dec4_2_col) VALUES (1, 1), (2, 2)"); + + assertQuery("SELECT AVG(dec4_2_col) FROM numbers") + .disableRules(rules) + .returns(new BigDecimal("1.50")) + .check(); + + sql("DELETE FROM numbers"); + sql("INSERT INTO numbers (id, dec4_2_col) VALUES (1, 1), (2, 2.3333)"); + + assertQuery("SELECT AVG(dec4_2_col) FROM numbers") + .disableRules(rules) + .returns(new BigDecimal("1.665")) + .check(); + + sql("DELETE FROM numbers"); + sql("INSERT INTO numbers (id, int_col, dec4_2_col) VALUES (1, null, null)"); + + assertQuery("SELECT AVG(int_col), AVG(dec4_2_col) FROM numbers") + .disableRules(rules) + .returns(null, null) + .check(); + + sql("DELETE FROM numbers"); + sql("INSERT INTO numbers (id, int_col, dec4_2_col) VALUES (1, 1, 1), (2, null, null)"); + + assertQuery("SELECT AVG(int_col), AVG(dec4_2_col) FROM numbers") + .disableRules(rules) + .returns(1, new BigDecimal("1.00")) + .check(); + } + + @Test + public void testAvgRandom() { + long seed = System.nanoTime(); + Random random = new Random(seed); + + sql("DELETE FROM numbers"); + + List<BigDecimal> numbers = new ArrayList<>(); + log.info("Seed: {}", seed); + + for (int i = 1; i < 20; i++) { + int val = random.nextInt(100) + 1; + BigDecimal num = BigDecimal.valueOf(val); + numbers.add(num); + + String query = "INSERT INTO numbers (id, int_col, dec10_2_col) VALUES(?, ?, ?)"; + sql(query, i, num.intValue(), num); + } + + BigDecimal avg = numbers.stream() + .reduce(new BigDecimal("0.00"), BigDecimal::add) + .divide(BigDecimal.valueOf(numbers.size()), MathContext.DECIMAL64); + + for (String[] rules : makePermutations(DISABLED_RULES)) { + assertQuery("SELECT AVG(int_col), AVG(dec10_2_col) FROM numbers") + .disableRules(rules) + .returns(avg.intValue(), avg) + .check(); + } + } + + @ParameterizedTest + @MethodSource("provideRules") + public void testAvgNullNotNull(String[] rules) { + sql("DELETE FROM not_null_numbers"); + sql("INSERT INTO not_null_numbers (id, int_col, dec4_2_col) VALUES (1, 1, 1), (2, 2, 2)"); + + assertQuery("SELECT AVG(int_col), AVG(dec4_2_col) FROM not_null_numbers") + .disableRules(rules) + .returns(1, new BigDecimal("1.50")) + .check(); + + // Return type of an AVG aggregate can never be null. + assertQuery("SELECT AVG(int_col) FROM not_null_numbers GROUP BY int_col") + .disableRules(rules) + .returns(1) + .returns(2) + .check(); + + assertQuery("SELECT AVG(dec4_2_col) FROM not_null_numbers GROUP BY dec4_2_col") + .disableRules(rules) + .returns(new BigDecimal("1.00")) + .returns(new BigDecimal("2.00")) + .check(); - // TODO https://issues.apache.org/jira/browse/IGNITE-20009 - // Remove after is fixed. - Assumptions.assumeFalse(Arrays.stream(rules) - .filter(COLO_RULES::contains).count() == COLO_RULES.size(), "AVG is disabled for MAP/REDUCE"); + sql("DELETE FROM numbers"); + sql("INSERT INTO numbers (id, int_col, dec4_2_col) VALUES (1, 1, 1), (2, 2, 2)"); - assertQuery("SELECT AVG(int_col) FROM test_str_int_real_dec") + assertQuery("SELECT AVG(int_col) FROM numbers GROUP BY int_col") .disableRules(rules) - .returns(new Object[]{null}) + .returns(1) + .returns(2) .check(); - assertQuery("SELECT AVG(real_col) FROM test_str_int_real_dec") + assertQuery("SELECT AVG(dec4_2_col) FROM numbers GROUP BY dec4_2_col") .disableRules(rules) - .returns(new Object[]{null}) + .returns(new BigDecimal("1.00")) + .returns(new BigDecimal("2.00")) + .check(); + } + + @ParameterizedTest + @MethodSource("provideRules") + public void testAvgOnEmptyGroup(String[] rules) { + sql("DELETE FROM numbers"); + + assertQuery("SELECT " + + "AVG(tinyint_col), AVG(smallint_col), AVG(int_col), AVG(bigint_col), " + + "AVG(float_col), AVG(double_col), AVG(dec2_col), AVG(dec4_2_col) " + + "FROM numbers") + .disableRules(rules) + .returns(null, null, null, null, null, null, null, null) + .check(); + } + + @ParameterizedTest + @MethodSource("provideRules") + public void testAvgFromLiterals(String[] rules) { + + assertQuery("SELECT " + + "AVG(tinyint_col), AVG(smallint_col), AVG(int_col), AVG(bigint_col), " + + "AVG(float_col), AVG(double_col), AVG(dec2_col), AVG(dec4_2_col) " + + "FROM (VALUES " + + "(1::TINYINT, 1::SMALLINT, 1::INTEGER, 1::BIGINT, 1::REAL, 1::DOUBLE, 1::DECIMAL(2), 1.00::DECIMAL(4,2)), " + + "(2::TINYINT, 2::SMALLINT, 2::INTEGER, 2::BIGINT, 2::REAL, 2::DOUBLE, 2::DECIMAL(2), 2.00::DECIMAL(4,2)) " + + ") " + + "t(tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, dec2_col, dec4_2_col)") + .disableRules(rules) + .returns((byte) 1, (short) 1, 1, 1L, 1.5f, 1.5d, new BigDecimal("1.5"), new BigDecimal("1.50")) .check(); - assertQuery("SELECT AVG(dec_col) FROM test_str_int_real_dec") + assertQuery("SELECT " + + "AVG(1::TINYINT), AVG(2::SMALLINT), AVG(3::INTEGER), AVG(4::BIGINT), " + + "AVG(5::REAL), AVG(6::DOUBLE), AVG(7::DECIMAL(2)), AVG(8.00::DECIMAL(4,2))") .disableRules(rules) - .returns(new Object[]{null}) + .returns((byte) 1, (short) 2, 3, 4L, 5.0f, 6.0d, new BigDecimal("7"), new BigDecimal("8.00")) + .check(); + + assertQuery("SELECT AVG(dec2_col), AVG(dec4_2_col) FROM\n" + + "(SELECT \n" + + " 1::DECIMAL(2) as dec2_col, 2.00::DECIMAL(4, 2) as dec4_2_col\n" + + " UNION\n" + + " SELECT 2::DECIMAL(2) as dec2_col, 3.00::DECIMAL(4,2) as dec4_2_col\n" + + ") as t") + .returns(new BigDecimal("1.5"), new BigDecimal("2.50")) .check(); } diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/LogicalRelImplementor.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/LogicalRelImplementor.java index 6b44790bd7..5ebeedd8a5 100644 --- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/LogicalRelImplementor.java +++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/LogicalRelImplementor.java @@ -744,7 +744,10 @@ public class LogicalRelImplementor<RowT> implements IgniteRelVisitor<Node<RowT>> RelDataType rowType = rel.getRowType(); Supplier<List<AccumulatorWrapper<RowT>>> accFactory = expressionFactory.accumulatorsFactory( - type, rel.getAggregateCalls(), null); + type, + rel.getAggregateCalls(), + rel.getInput().getRowType() + ); RowSchema rowSchema = rowSchemaFromRelTypes(RelOptUtil.getFieldTypeList(rowType)); RowFactory<RowT> rowFactory = ctx.rowHandler().factory(rowSchema); diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctions.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctions.java index 5f1f31d5d0..eb4639c45e 100644 --- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctions.java +++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctions.java @@ -25,6 +25,7 @@ import static org.apache.ignite.lang.ErrorGroups.Sql.RUNTIME_ERR; import java.math.BigDecimal; import java.math.BigInteger; +import java.math.MathContext; import java.math.RoundingMode; import java.time.LocalDateTime; import java.time.LocalTime; @@ -51,6 +52,7 @@ import org.apache.calcite.schema.Statistic; import org.apache.calcite.sql.SqlCall; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.ignite.internal.sql.engine.sql.fun.IgniteSqlOperatorTable; import org.apache.ignite.internal.sql.engine.type.IgniteTypeSystem; import org.apache.ignite.internal.sql.engine.util.Commons; import org.apache.ignite.internal.sql.engine.util.TypeUtils; @@ -473,6 +475,14 @@ public class IgniteSqlFunctions { } } + /** + * Division function for REDUCE phase of AVG aggregate. Precision and scale is only used by type inference + * (see {@link IgniteSqlOperatorTable#DECIMAL_DIVIDE}, their values are ignored at runtime. + */ + public static BigDecimal decimalDivide(BigDecimal sum, BigDecimal cnt, int p, int s) { + return sum.divide(cnt, MathContext.DECIMAL64); + } + private static BigDecimal processValueWithIntegralPart(Number value, int precision, int scale) { BigDecimal dec = convertToBigDecimal(value); diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/RexImpTable.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/RexImpTable.java index 48db266aad..334446bbea 100644 --- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/RexImpTable.java +++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/RexImpTable.java @@ -1017,6 +1017,7 @@ public class RexImpTable { defineMethod(ROUND, IgniteMethod.ROUND.method(), NullPolicy.STRICT); defineMethod(TRUNCATE, IgniteMethod.TRUNCATE.method(), NullPolicy.STRICT); defineMethod(IgniteSqlOperatorTable.SUBSTRING, IgniteMethod.SUBSTRING.method(), NullPolicy.STRICT); + defineMethod(IgniteSqlOperatorTable.DECIMAL_DIVIDE, IgniteMethod.DECIMAL_DIVIDE.method(), NullPolicy.ARG0); map.put(TYPEOF, systemFunctionImplementor); map.put(SYSTEM_RANGE, systemFunctionImplementor); diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rel/agg/MapReduceAggregates.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rel/agg/MapReduceAggregates.java index 95b060088d..ec8c2819e4 100644 --- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rel/agg/MapReduceAggregates.java +++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rel/agg/MapReduceAggregates.java @@ -20,6 +20,7 @@ package org.apache.ignite.internal.sql.engine.rel.agg; import static org.apache.ignite.internal.lang.IgniteStringFormatter.format; import com.google.common.collect.ImmutableList; +import java.math.BigDecimal; import java.util.AbstractMap.SimpleEntry; import java.util.ArrayList; import java.util.List; @@ -34,19 +35,24 @@ import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeFactory.Builder; import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rel.type.RelDataTypeSystem; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.type.SqlTypeUtil; import org.apache.calcite.util.ImmutableBitSet; import org.apache.calcite.util.mapping.Mapping; import org.apache.calcite.util.mapping.Mappings; import org.apache.ignite.internal.sql.engine.rel.IgniteProject; import org.apache.ignite.internal.sql.engine.rel.IgniteRel; +import org.apache.ignite.internal.sql.engine.sql.fun.IgniteSqlOperatorTable; import org.apache.ignite.internal.sql.engine.type.IgniteTypeFactory; import org.apache.ignite.internal.sql.engine.util.Commons; +import org.jetbrains.annotations.TestOnly; /** * Map/reduce aggregate utility methods. @@ -62,6 +68,7 @@ public class MapReduceAggregates { "EVERY", "SOME", "ANY", + "AVG", "SINGLE_VALUE", "ANY_VALUE" ); @@ -92,6 +99,10 @@ public class MapReduceAggregates { /** * Creates a physical operator that implements the given logical aggregate as MAP/REDUCE. * + * <p>Final expression consists of an aggregate for map phase (MapNode) and another aggregate for reduce phase (ReduceNode). + * Depending on a accumulator function there can be an intermediate projection between MAP and REDUCE, and a projection after REDUCE + * as well. + * * @param agg Logical aggregate expression. * @param builder Builder to create implementations of MAP and REDUCE phases. * @param fieldMappingOnReduce Mapping to be applied to group sets on REDUCE phase. @@ -125,14 +136,25 @@ public class MapReduceAggregates { // groupSet includes all columns from GROUP BY/GROUPING SETS clauses. int argumentOffset = agg.getGroupSet().cardinality(); + // MAP PHASE AGGREGATE + List<AggregateCall> mapAggCalls = new ArrayList<>(agg.getAggCallList().size()); for (AggregateCall call : agg.getAggCallList()) { - MapReduceAgg mapReduceAgg = createMapReduceAggCall(call, argumentOffset); - argumentOffset += 1; + // See ReturnTypes::AVG_AGG_FUNCTION, Result type of a aggregate with no grouping or with filtering can be nullable. + boolean canBeNull = agg.getGroupCount() == 0 || call.hasFilter(); + + MapReduceAgg mapReduceAgg = createMapReduceAggCall( + Commons.cluster(), + call, + argumentOffset, + agg.getInput().getRowType(), + canBeNull + ); + argumentOffset += mapReduceAgg.reduceCalls.size(); mapReduceAggs.add(mapReduceAgg); - mapAggCalls.add(mapReduceAgg.mapCall); + mapAggCalls.addAll(mapReduceAgg.mapCalls); } // MAP phase should have no less than the number of arguments as original aggregate. @@ -148,7 +170,10 @@ public class MapReduceAggregates { mapAggCalls ); - List<RelDataTypeField> outputRowFields = agg.getRowType().getFieldList(); + // + // REDUCE INPUT PROJECTION + // + RelDataTypeFactory.Builder reduceType = new Builder(Commons.typeFactory()); int groupByColumns = agg.getGroupSet().cardinality(); @@ -158,26 +183,84 @@ public class MapReduceAggregates { // It consists of columns from agg.groupSet and aggregate expressions. for (int i = 0; i < groupByColumns; i++) { + List<RelDataTypeField> outputRowFields = agg.getRowType().getFieldList(); RelDataType type = outputRowFields.get(i).getType(); reduceType.add("f" + reduceType.getFieldCount(), type); } + RexBuilder rexBuilder = agg.getCluster().getRexBuilder(); + IgniteTypeFactory typeFactory = (IgniteTypeFactory) agg.getCluster().getTypeFactory(); + + List<RexNode> reduceInputExprs = new ArrayList<>(); + + for (int i = 0; i < map.getRowType().getFieldList().size(); i++) { + RelDataType type = map.getRowType().getFieldList().get(i).getType(); + RexInputRef ref = new RexInputRef(i, type); + reduceInputExprs.add(ref); + } + + // Build a list of projections for reduce operator, + // if all projections are identity, it is not necessary + // to create a projection between MAP and REDUCE operators. + + boolean additionalProjectionsForReduce = false; + + for (int i = 0, argOffset = 0; i < mapReduceAggs.size(); i++) { + MapReduceAgg mapReduceAgg = mapReduceAggs.get(i); + int argIdx = groupByColumns + argOffset; + + for (int j = 0; j < mapReduceAgg.reduceCalls.size(); j++) { + RexNode projExpr = mapReduceAgg.makeReduceInputExpr.makeExpr(rexBuilder, map, List.of(argIdx), typeFactory); + reduceInputExprs.set(argIdx, projExpr); + + if (mapReduceAgg.makeReduceInputExpr != USE_INPUT_FIELD) { + additionalProjectionsForReduce = true; + } + + argIdx += 1; + } + + argOffset += mapReduceAgg.reduceCalls.size(); + } + + RelNode reduceInputNode; + if (additionalProjectionsForReduce) { + RelDataTypeFactory.Builder projectRow = new Builder(agg.getCluster().getTypeFactory()); + + for (int i = 0; i < reduceInputExprs.size(); i++) { + RexNode rexNode = reduceInputExprs.get(i); + projectRow.add(String.valueOf(i), rexNode.getType()); + } + + RelDataType projectRowType = projectRow.build(); + + reduceInputNode = builder.makeProject(agg.getCluster(), map, reduceInputExprs, projectRowType); + } else { + reduceInputNode = map; + } + + // + // REDUCE PHASE AGGREGATE + // // Build a list of aggregate calls for REDUCE phase. - // Build a list of projection that accept reduce phase and combine/collect/cast results. + // Build a list of projections (arg-list, expr) that accept reduce phase and combine/collect/cast results. List<AggregateCall> reduceAggCalls = new ArrayList<>(); - List<Map.Entry<List<Integer>, MakeReduceExpr>> projection = new ArrayList<>(); + List<Map.Entry<List<Integer>, MakeReduceExpr>> projection = new ArrayList<>(mapReduceAggs.size()); for (MapReduceAgg mapReduceAgg : mapReduceAggs) { // Update row type returned by REDUCE node. - AggregateCall reduceCall = mapReduceAgg.reduceCall; - reduceType.add("f" + reduceType.getFieldCount(), reduceCall.getType()); - reduceAggCalls.add(reduceCall); + int i = 0; + for (AggregateCall reduceCall : mapReduceAgg.reduceCalls) { + reduceType.add("f" + i + "_" + reduceType.getFieldCount(), reduceCall.getType()); + reduceAggCalls.add(reduceCall); + i += 1; + } // Update projection list - List<Integer> argList = mapReduceAgg.argList; - MakeReduceExpr projectionExpr = mapReduceAgg.makeReduceExpr; - projection.add(new SimpleEntry<>(argList, projectionExpr)); + List<Integer> reduceArgList = mapReduceAgg.argList; + MakeReduceExpr projectionExpr = mapReduceAgg.makeReduceOutputExpr; + projection.add(new SimpleEntry<>(reduceArgList, projectionExpr)); if (projectionExpr != USE_INPUT_FIELD) { sameAggsForBothPhases = false; @@ -192,7 +275,7 @@ public class MapReduceAggregates { } // if the number of aggregates on MAP phase is larger then the number of aggregates on REDUCE phase, - // then some of MAP aggregates are not used by REDUCE phase and this is a bug. + // assume that some of MAP aggregates are not used by REDUCE phase and this is a bug. // // NOTE: In general case REDUCE phase can use more aggregates than MAP phase, // but at the moment there is no support for such aggregates. @@ -207,20 +290,23 @@ public class MapReduceAggregates { IgniteRel reduce = builder.makeReduceAgg( agg.getCluster(), - map, + reduceInputNode, groupSetOnReduce, groupSetsOnReduce, reduceAggCalls, reduceTypeToUse ); + // + // FINAL PROJECTION + // // if aggregate MAP phase uses the same aggregates as REDUCE phase, // there is no need to add a projection because no additional actions are required to compute final results. if (sameAggsForBothPhases) { return reduce; } - List<RexNode> projectionList = new ArrayList<>(projection.size()); + List<RexNode> projectionList = new ArrayList<>(projection.size() + groupByColumns); // Projection list returned by AggregateNode consists of columns from GROUP BY clause // and expressions that represent aggregate calls. @@ -228,17 +314,24 @@ public class MapReduceAggregates { int i = 0; for (; i < groupByColumns; i++) { + List<RelDataTypeField> outputRowFields = agg.getRowType().getFieldList(); RelDataType type = outputRowFields.get(i).getType(); RexInputRef ref = new RexInputRef(i, type); projectionList.add(ref); } - RexBuilder rexBuilder = agg.getCluster().getRexBuilder(); - IgniteTypeFactory typeFactory = (IgniteTypeFactory) agg.getCluster().getTypeFactory(); - for (Map.Entry<List<Integer>, MakeReduceExpr> expr : projection) { RexNode resultExpr = expr.getValue().makeExpr(rexBuilder, reduce, expr.getKey(), typeFactory); projectionList.add(resultExpr); + } + + assert projectionList.size() == agg.getRowType().getFieldList().size() : + format("Projection size does not match. Expected: {} but got {}", + agg.getRowType().getFieldList().size(), projectionList.size()); + + for (i = 0; i < projectionList.size(); i++) { + RexNode resultExpr = projectionList.get(i); + List<RelDataTypeField> outputRowFields = agg.getRowType().getFieldList(); // Put assertion here so we can see an expression that caused a type mismatch, // since Project::isValid only shows types. @@ -246,7 +339,6 @@ public class MapReduceAggregates { format("Type at position#{} does not match. Expected: {} but got {}.\nREDUCE aggregates: {}\nRow: {}.\nExpr: {}", i, resultExpr.getType(), outputRowFields.get(i).getType(), reduceAggCalls, outputRowFields, resultExpr); - i++; } return new IgniteProject(agg.getCluster(), reduce.getTraitSet(), reduce, projectionList, agg.getRowType()); @@ -255,15 +347,24 @@ public class MapReduceAggregates { /** * Creates a MAP/REDUCE details for this call. */ - public static MapReduceAgg createMapReduceAggCall(AggregateCall call, int reduceArgumentOffset) { + public static MapReduceAgg createMapReduceAggCall( + RelOptCluster cluster, + AggregateCall call, + int reduceArgumentOffset, + RelDataType input, + boolean canBeNull + ) { String aggName = call.getAggregation().getName(); assert AGG_SUPPORTING_MAP_REDUCE.contains(aggName) : "Aggregate does not support MAP/REDUCE " + call; - if ("COUNT".equals(aggName)) { - return createCountAgg(call, reduceArgumentOffset); - } else { - return createSimpleAgg(call, reduceArgumentOffset); + switch (aggName) { + case "COUNT": + return createCountAgg(call, reduceArgumentOffset); + case "AVG": + return createAvgAgg(cluster, call, reduceArgumentOffset, input, canBeNull); + default: + return createSimpleAgg(call, reduceArgumentOffset); } } @@ -277,8 +378,13 @@ public class MapReduceAggregates { IgniteRel makeMapAgg(RelOptCluster cluster, RelNode input, ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets, List<AggregateCall> aggregateCalls); + /** + * Creates intermediate projection operator that transforms results from MAP phase, and transforms them to inputs to REDUCE phase. + */ + IgniteRel makeProject(RelOptCluster cluster, RelNode input, List<RexNode> reduceInputExprs, RelDataType projectRowType); + /** Creates a rel node that represents a REDUCE phase.*/ - IgniteRel makeReduceAgg(RelOptCluster cluster, RelNode map, + IgniteRel makeReduceAgg(RelOptCluster cluster, RelNode input, ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets, List<AggregateCall> aggregateCalls, RelDataType outputType); } @@ -286,24 +392,48 @@ public class MapReduceAggregates { /** Contains information on how to build MAP/REDUCE version of an aggregate. */ public static class MapReduceAgg { + /** Argument list on reduce phase. */ final List<Integer> argList; - final AggregateCall mapCall; + /** MAP phase aggregate, an initial aggregation function was transformed into. */ + final List<AggregateCall> mapCalls; + + /** REDUCE phase aggregate, an initial aggregation function was transformed into. */ + final List<AggregateCall> reduceCalls; - final AggregateCall reduceCall; + /** Produces expressions to consume results of a MAP phase aggregation. */ + final MakeReduceExpr makeReduceInputExpr; - final MakeReduceExpr makeReduceExpr; + /** Produces expressions to consume results of a REDUCE phase aggregation to comprise final result. */ + final MakeReduceExpr makeReduceOutputExpr; - MapReduceAgg(List<Integer> argList, AggregateCall mapCall, AggregateCall reduceCall, MakeReduceExpr makeReduceExpr) { + MapReduceAgg( + List<Integer> argList, + AggregateCall mapCalls, + AggregateCall reduceCalls, + MakeReduceExpr makeReduceOutputExpr + ) { + this(argList, List.of(mapCalls), USE_INPUT_FIELD, List.of(reduceCalls), makeReduceOutputExpr); + } + + MapReduceAgg( + List<Integer> argList, + List<AggregateCall> mapCalls, + MakeReduceExpr makeReduceInputExpr, + List<AggregateCall> reduceCalls, + MakeReduceExpr makeReduceOutputExpr + ) { this.argList = argList; - this.mapCall = mapCall; - this.reduceCall = reduceCall; - this.makeReduceExpr = makeReduceExpr; + this.mapCalls = mapCalls; + this.reduceCalls = reduceCalls; + this.makeReduceInputExpr = makeReduceInputExpr; + this.makeReduceOutputExpr = makeReduceOutputExpr; } /** A call for REDUCE phase. */ + @TestOnly public AggregateCall getReduceCall() { - return reduceCall; + return reduceCalls.get(0); } } @@ -322,13 +452,13 @@ public class MapReduceAggregates { null, call.collation, call.type, - null); + "COUNT_" + reduceArgumentOffset + "_MAP_SUM"); // COUNT(x) aggregate have type BIGINT, but the type of SUM(COUNT(x)) is DECIMAL, // so we should convert it to back to BIGINT. MakeReduceExpr exprBuilder = (rexBuilder, input, args, typeFactory) -> { RexInputRef ref = rexBuilder.makeInputRef(input, args.get(0)); - return rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.BIGINT), ref); + return rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.BIGINT), ref, true, false); }; return new MapReduceAgg(argList, call, sum0, exprBuilder); @@ -356,10 +486,189 @@ public class MapReduceAggregates { return new MapReduceAgg(argList, call, reduceCall, USE_INPUT_FIELD); } + /** + * Produces intermediate expressions that modify results of MAP/REDUCE aggregate. + * For example: after splitting a function into a MAP aggregate and REDUCE aggregate it is necessary to add casts to + * output of a REDUCE phase aggregate. + * + * <p>In order to avoid creating unnecessary projections, use {@link MapReduceAggregates#USE_INPUT_FIELD}. + */ @FunctionalInterface private interface MakeReduceExpr { - /** Creates an expression that produces result of REDUCE phase of an aggregate. */ + /** + * Creates an expression that applies a performs computation (e.g. applies some function, adds a cast) + * on {@code args} fields of input relation. + * + * @param rexBuilder Expression builder. + * @param input Input relation. + * @param args Arguments. + * @param typeFactory Type factory. + * + * @return Expression. + */ RexNode makeExpr(RexBuilder rexBuilder, RelNode input, List<Integer> args, IgniteTypeFactory typeFactory); } + + private static MapReduceAgg createAvgAgg( + RelOptCluster cluster, + AggregateCall call, + int reduceArgumentOffset, + RelDataType inputType, + boolean canBeNull + ) { + RelDataTypeFactory tf = cluster.getTypeFactory(); + RelDataTypeSystem typeSystem = tf.getTypeSystem(); + + RelDataType fieldType = inputType.getFieldList().get(call.getArgList().get(0)).getType(); + + // In case of AVG(NULL) return a simple version of an aggregate, because result is always NULL. + if (fieldType.getSqlTypeName() == SqlTypeName.NULL) { + return createSimpleAgg(call, reduceArgumentOffset); + } + + // AVG(x) : SUM(x)/COUNT0(x) + // MAP : SUM(x) / COUNT(x) + + // SUM(x) as s + RelDataType mapSumType = typeSystem.deriveSumType(tf, fieldType); + if (canBeNull) { + mapSumType = tf.createTypeWithNullability(mapSumType, true); + } + + AggregateCall mapSum0 = AggregateCall.create( + SqlStdOperatorTable.SUM, + call.isDistinct(), + call.isApproximate(), + call.ignoreNulls(), + ImmutableList.of(), + call.getArgList(), + call.filterArg, + null, + call.collation, + mapSumType, + "AVG_SUM" + reduceArgumentOffset); + + // COUNT(x) as c + RelDataType mapCountType = tf.createSqlType(SqlTypeName.BIGINT); + + AggregateCall mapCount0 = AggregateCall.create( + SqlStdOperatorTable.COUNT, + call.isDistinct(), + call.isApproximate(), + call.ignoreNulls(), + ImmutableList.of(), + call.getArgList(), + call.filterArg, + null, + call.collation, + mapCountType, + "AVG_COUNT" + reduceArgumentOffset); + + // REDUCE : SUM(s) as reduce_sum, SUM0(c) as reduce_count + List<Integer> reduceSumArgs = List.of(reduceArgumentOffset); + + // SUM0(s) + RelDataType reduceSumType = typeSystem.deriveSumType(tf, mapSumType); + if (canBeNull) { + reduceSumType = tf.createTypeWithNullability(reduceSumType, true); + } + + AggregateCall reduceSum0 = AggregateCall.create( + SqlStdOperatorTable.SUM, + call.isDistinct(), + call.isApproximate(), + call.ignoreNulls(), + ImmutableList.of(), + reduceSumArgs, + // there is no filtering on REDUCE phase + -1, + null, + call.collation, + reduceSumType, + "AVG_SUM" + reduceArgumentOffset); + + + // SUM0(c) + RelDataType reduceSumCountType = typeSystem.deriveSumType(tf, mapCount0.type); + List<Integer> reduceSumCountArgs = List.of(reduceArgumentOffset + 1); + + AggregateCall reduceSumCount = AggregateCall.create( + SqlStdOperatorTable.SUM0, + call.isDistinct(), + call.isApproximate(), + call.ignoreNulls(), + ImmutableList.of(), + reduceSumCountArgs, + // there is no filtering on REDUCE phase + -1, + null, + call.collation, + reduceSumCountType, + "AVG_SUM0" + reduceArgumentOffset); + + RelDataType finalReduceSumType = reduceSumType; + + MakeReduceExpr reduceInputExpr = (rexBuilder, input, args, typeFactory) -> { + RexInputRef argExpr = rexBuilder.makeInputRef(input, args.get(0)); + + if (args.get(0) == reduceArgumentOffset) { + // Accumulator functions handle NULL, so it is safe to ignore it. + if (!SqlTypeUtil.equalSansNullability(finalReduceSumType, argExpr.getType())) { + return rexBuilder.makeCast(finalReduceSumType, argExpr, true, false); + } else { + return argExpr; + } + } else { + return rexBuilder.makeCast(reduceSumCount.type, argExpr, true, false); + } + + }; + + // PROJECT: reduce_sum/reduce_count + + MakeReduceExpr reduceOutputExpr = (rexBuilder, input, args, typeFactory) -> { + RexNode numeratorRef = rexBuilder.makeInputRef(input, args.get(0)); + RexInputRef denominatorRef = rexBuilder.makeInputRef(input, args.get(1)); + + RelDataType avgType = typeFactory.createTypeWithNullability(mapSum0.type, numeratorRef.getType().isNullable()); + numeratorRef = rexBuilder.ensureType(avgType, numeratorRef, true); + + RexNode sumDivCnt; + if (call.getType().getSqlTypeName() == SqlTypeName.DECIMAL) { + // Return correct decimal type with correct scale and precision. + int precision = call.getType().getPrecision(); + int scale = call.getType().getScale(); + + RexLiteral p = rexBuilder.makeExactLiteral(BigDecimal.valueOf(precision), tf.createSqlType(SqlTypeName.INTEGER)); + RexLiteral s = rexBuilder.makeExactLiteral(BigDecimal.valueOf(scale), tf.createSqlType(SqlTypeName.INTEGER)); + + sumDivCnt = rexBuilder.makeCall(IgniteSqlOperatorTable.DECIMAL_DIVIDE, numeratorRef, denominatorRef, p, s); + } else { + RexNode divideRef = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numeratorRef, denominatorRef); + sumDivCnt = rexBuilder.makeCast(call.getType(), divideRef, true, false); + } + + if (canBeNull) { + // CASE cnt == 0 THEN null + // OTHERWISE sum / cnt + RexLiteral zero = rexBuilder.makeExactLiteral(BigDecimal.ZERO, denominatorRef.getType()); + RexNode eqZero = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, numeratorRef, zero); + RexLiteral nullRes = rexBuilder.makeNullLiteral(call.getType()); + + return rexBuilder.makeCall(SqlStdOperatorTable.CASE, eqZero, nullRes, sumDivCnt); + } else { + return sumDivCnt; + } + }; + + List<Integer> argList = List.of(reduceArgumentOffset, reduceArgumentOffset + 1); + return new MapReduceAgg( + argList, + List.of(mapSum0, mapCount0), + reduceInputExpr, + List.of(reduceSum0, reduceSumCount), + reduceOutputExpr + ); + } } diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rule/HashAggregateConverterRule.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rule/HashAggregateConverterRule.java index b14010d030..15d72f5f7f 100644 --- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rule/HashAggregateConverterRule.java +++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rule/HashAggregateConverterRule.java @@ -31,9 +31,11 @@ import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexNode; import org.apache.calcite.util.ImmutableBitSet; import org.apache.calcite.util.mapping.Mapping; import org.apache.ignite.internal.sql.engine.rel.IgniteConvention; +import org.apache.ignite.internal.sql.engine.rel.IgniteProject; import org.apache.ignite.internal.sql.engine.rel.IgniteRel; import org.apache.ignite.internal.sql.engine.rel.agg.IgniteColocatedHashAggregate; import org.apache.ignite.internal.sql.engine.rel.agg.IgniteMapHashAggregate; @@ -106,7 +108,10 @@ public class HashAggregateConverterRule { RelTraitSet outTrait = cluster.traitSetOf(IgniteConvention.INSTANCE); Mapping fieldMappingOnReduce = Commons.trimmingMapping(agg.getGroupSet().length(), agg.getGroupSet()); + RelTraitSet reducePhaseTraits = outTrait.replace(IgniteDistributions.single()); + AggregateRelBuilder relBuilder = new AggregateRelBuilder() { + @Override public IgniteRel makeMapAgg(RelOptCluster cluster, RelNode input, ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets, List<AggregateCall> aggregateCalls) { @@ -121,13 +126,26 @@ public class HashAggregateConverterRule { } @Override - public IgniteRel makeReduceAgg(RelOptCluster cluster, RelNode map, ImmutableBitSet groupSet, + public IgniteRel makeProject(RelOptCluster cluster, RelNode input, List<RexNode> reduceInputExprs, + RelDataType projectRowType) { + + return new IgniteProject( + agg.getCluster(), + reducePhaseTraits, + convert(input, inTrait.replace(IgniteDistributions.single())), + reduceInputExprs, + projectRowType + ); + } + + @Override + public IgniteRel makeReduceAgg(RelOptCluster cluster, RelNode input, ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets, List<AggregateCall> aggregateCalls, RelDataType outputType) { return new IgniteReduceHashAggregate( cluster, - outTrait.replace(IgniteDistributions.single()), - convert(map, inTrait.replace(IgniteDistributions.single())), + reducePhaseTraits, + convert(input, inTrait.replace(IgniteDistributions.single())), groupSet, groupSets, aggregateCalls, diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rule/SortAggregateConverterRule.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rule/SortAggregateConverterRule.java index 6dd3c82009..5b90131995 100644 --- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rule/SortAggregateConverterRule.java +++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rule/SortAggregateConverterRule.java @@ -34,9 +34,11 @@ import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexNode; import org.apache.calcite.util.ImmutableBitSet; import org.apache.calcite.util.mapping.Mapping; import org.apache.ignite.internal.sql.engine.rel.IgniteConvention; +import org.apache.ignite.internal.sql.engine.rel.IgniteProject; import org.apache.ignite.internal.sql.engine.rel.IgniteRel; import org.apache.ignite.internal.sql.engine.rel.agg.IgniteColocatedSortAggregate; import org.apache.ignite.internal.sql.engine.rel.agg.IgniteMapSortAggregate; @@ -135,6 +137,7 @@ public class SortAggregateConverterRule { RelTraitSet outTraits = cluster.traitSetOf(IgniteConvention.INSTANCE).replace(outputCollation); AggregateRelBuilder relBuilder = new AggregateRelBuilder() { + @Override public IgniteRel makeMapAgg(RelOptCluster cluster, RelNode input, ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets, List<AggregateCall> aggregateCalls) { @@ -151,13 +154,25 @@ public class SortAggregateConverterRule { } @Override - public IgniteRel makeReduceAgg(RelOptCluster cluster, RelNode map, ImmutableBitSet groupSet, + public IgniteRel makeProject(RelOptCluster cluster, RelNode input, List<RexNode> reduceInputExprs, + RelDataType projectRowType) { + + return new IgniteProject(agg.getCluster(), + outTraits.replace(IgniteDistributions.single()), + convert(input, outTraits.replace(IgniteDistributions.single())), + reduceInputExprs, + projectRowType + ); + } + + @Override + public IgniteRel makeReduceAgg(RelOptCluster cluster, RelNode input, ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets, List<AggregateCall> aggregateCalls, RelDataType outputType) { return new IgniteReduceSortAggregate( cluster, outTraits.replace(IgniteDistributions.single()), - convert(map, outTraits.replace(IgniteDistributions.single())), + convert(input, outTraits.replace(IgniteDistributions.single())), groupSet, groupSets, aggregateCalls, diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/sql/fun/IgniteSqlOperatorTable.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/sql/fun/IgniteSqlOperatorTable.java index 141bff461a..6d2e733f20 100644 --- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/sql/fun/IgniteSqlOperatorTable.java +++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/sql/fun/IgniteSqlOperatorTable.java @@ -19,6 +19,7 @@ package org.apache.ignite.internal.sql.engine.sql.fun; import com.google.common.collect.ImmutableList; import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.sql.SqlBasicFunction; import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlFunctionCategory; @@ -190,6 +191,31 @@ public class IgniteSqlOperatorTable extends ReflectiveSqlOperatorTable { OperandTypes.CHARACTER.or(OperandTypes.BINARY), SqlFunctionCategory.NUMERIC); + /** + * Division operator used by REDUCE phase of AVG aggregate. + * Uses provided values of {@code scale} and {@code precision} to return inferred type. + */ + public static final SqlFunction DECIMAL_DIVIDE = SqlBasicFunction.create("DECIMAL_DIVIDE", + new SqlReturnTypeInference() { + @Override + public @Nullable RelDataType inferReturnType(SqlOperatorBinding opBinding) { + RelDataType arg0 = opBinding.getOperandType(0); + Integer precision = opBinding.getOperandLiteralValue(2, Integer.class); + Integer scale = opBinding.getOperandLiteralValue(3, Integer.class); + + assert precision != null : "precision is not specified: " + opBinding.getOperator(); + assert scale != null : "scale is not specified: " + opBinding.getOperator(); + + boolean nullable = arg0.isNullable(); + RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); + + RelDataType returnType = typeFactory.createSqlType(SqlTypeName.DECIMAL, precision, scale); + return typeFactory.createTypeWithNullability(returnType, nullable); + } + }, + OperandTypes.DIVISION_OPERATOR, + SqlFunctionCategory.NUMERIC); + /** Singleton instance. */ public static final IgniteSqlOperatorTable INSTANCE = new IgniteSqlOperatorTable(); @@ -235,6 +261,7 @@ public class IgniteSqlOperatorTable extends ReflectiveSqlOperatorTable { register(SqlStdOperatorTable.SUM); register(SqlStdOperatorTable.SUM0); register(SqlStdOperatorTable.AVG); + register(DECIMAL_DIVIDE); register(SqlStdOperatorTable.MIN); register(SqlStdOperatorTable.MAX); register(SqlStdOperatorTable.ANY_VALUE); diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/IgniteMethod.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/IgniteMethod.java index 5d419abbf7..0ac553b9ec 100644 --- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/IgniteMethod.java +++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/IgniteMethod.java @@ -21,6 +21,7 @@ import static org.apache.ignite.internal.lang.IgniteStringFormatter.format; import java.lang.reflect.Method; import java.lang.reflect.Type; +import java.math.BigDecimal; import java.util.Arrays; import java.util.Objects; import java.util.TimeZone; @@ -125,7 +126,14 @@ public enum IgniteMethod { */ TRUNCATE(IgniteSqlFunctions.class, "struncate", true), - SUBSTRING(IgniteSqlFunctions.class, "substring", true); + SUBSTRING(IgniteSqlFunctions.class, "substring", true), + + /** + * Division operator used by REDUCE phase of AVG aggregate. + * See {@link IgniteSqlFunctions#decimalDivide(BigDecimal, BigDecimal, int, int)}. + */ + DECIMAL_DIVIDE(IgniteSqlFunctions.class, "decimalDivide", BigDecimal.class, BigDecimal.class, int.class, int.class), + ; private final Method method; diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/PlanUtils.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/PlanUtils.java index 24e5e49602..90fe27e614 100644 --- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/PlanUtils.java +++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/PlanUtils.java @@ -127,9 +127,17 @@ public class PlanUtils { for (int i = 0; i < aggregateCalls.size(); i++) { AggregateCall call = aggregateCalls.get(i); - Accumulator acc = accumulators.accumulatorFactory(call).get(); - RelDataType fieldType = acc.returnType(typeFactory); + RelDataType fieldType; + // For a decimal type Accumulator::returnType returns a type with default precision and scale, + // that can cause precision loss when a tuple is sent over the wire by an exchanger/outbox. + // Outbox uses its input type as wire format, so if a scale is 0, then the scale is lost + // (see Outbox::sendBatch -> RowHandler::toBinaryTuple -> BinaryTupleBuilder::appendDecimalNotNull). + if (call.getType().getSqlTypeName().allowsScale()) { + fieldType = call.type; + } else { + fieldType = acc.returnType(typeFactory); + } String fieldName = "_ACC" + i; builder.add(fieldName, fieldType); diff --git a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctionsTest.java b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctionsTest.java index 5b2b77cb30..813ffb9532 100644 --- a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctionsTest.java +++ b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctionsTest.java @@ -34,6 +34,7 @@ import java.util.function.Supplier; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.ignite.internal.sql.engine.type.IgniteTypeSystem; import org.apache.ignite.lang.ErrorGroups.Sql; +import org.jetbrains.annotations.Nullable; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; @@ -516,4 +517,26 @@ public class IgniteSqlFunctionsTest { assertEquals(Instant.ofEpochMilli(expMillis), Instant.ofEpochMilli(actualTs)); } + + @ParameterizedTest + @CsvSource( + value = { + "1; 2; 0.5", + "1; 3; 0.3333333333333333", + "6; 2; 3", + "1; 0;", + }, + delimiterString = ";" + ) + public void testAvgDivide(String a, String b, @Nullable String expected) { + BigDecimal num = new BigDecimal(a); + BigDecimal denum = new BigDecimal(b); + + if (expected != null) { + BigDecimal actual = IgniteSqlFunctions.decimalDivide(num, denum, 4, 2); + assertEquals(new BigDecimal(expected), actual); + } else { + assertThrows(ArithmeticException.class, () -> IgniteSqlFunctions.decimalDivide(num, denum, 4, 2)); + } + } } diff --git a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/HashAggregateExecutionTest.java b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/HashAggregateExecutionTest.java index daaa6e7f25..bb4355ca79 100644 --- a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/HashAggregateExecutionTest.java +++ b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/HashAggregateExecutionTest.java @@ -118,7 +118,13 @@ public class HashAggregateExecutionTest extends BaseAggregateTest { ImmutableBitSet grpSet = grpSets.get(0); Mapping reduceMapping = Commons.trimmingMapping(grpSet.length(), grpSet); - MapReduceAgg mapReduceAgg = MapReduceAggregates.createMapReduceAggCall(call, reduceMapping.getTargetCount()); + MapReduceAgg mapReduceAgg = MapReduceAggregates.createMapReduceAggCall( + Commons.cluster(), + call, + reduceMapping.getTargetCount(), + inRowType, + true + ); HashAggregateNode<Object[]> aggRdc = new HashAggregateNode<>( ctx, diff --git a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/HashAggregateSingleGroupExecutionTest.java b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/HashAggregateSingleGroupExecutionTest.java index 249ca1b956..fdce4dbf2d 100644 --- a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/HashAggregateSingleGroupExecutionTest.java +++ b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/HashAggregateSingleGroupExecutionTest.java @@ -202,7 +202,7 @@ public class HashAggregateSingleGroupExecutionTest extends AbstractExecutionTest map.register(scan); RelDataType hashRowType = PlanUtils.createHashAggRowType(grpSets, tf, rowType, List.of(mapCall)); - MapReduceAgg reduceAggCall = MapReduceAggregates.createMapReduceAggCall(mapCall, 0); + MapReduceAgg reduceAggCall = MapReduceAggregates.createMapReduceAggCall(Commons.cluster(), mapCall, 0, rowType, true); HashAggregateNode<Object[]> reduce = new HashAggregateNode<>(ctx, REDUCE, grpSets, accFactory(ctx, reduceAggCall.getReduceCall(), REDUCE, hashRowType), rowFactory()); @@ -431,7 +431,7 @@ public class HashAggregateSingleGroupExecutionTest extends AbstractExecutionTest IgniteTypeFactory tf = Commons.typeFactory(); RelDataType hashRowType = PlanUtils.createHashAggRowType(grpSets, tf, rowType, List.of(call)); - MapReduceAgg reduceAggCall = MapReduceAggregates.createMapReduceAggCall(call, 0); + MapReduceAgg reduceAggCall = MapReduceAggregates.createMapReduceAggCall(Commons.cluster(), call, 0, rowType, true); return newHashAggNode(ctx, REDUCE, grpSets, hashRowType, reduceAggCall.getReduceCall()); } diff --git a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/SortAggregateExecutionTest.java b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/SortAggregateExecutionTest.java index f40ccda15c..765ac75912 100644 --- a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/SortAggregateExecutionTest.java +++ b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/SortAggregateExecutionTest.java @@ -137,7 +137,13 @@ public class SortAggregateExecutionTest extends BaseAggregateTest { } Mapping mapping = Commons.trimmingMapping(grpSet.length(), grpSet); - MapReduceAgg mapReduceAgg = MapReduceAggregates.createMapReduceAggCall(call, mapping.getTargetCount()); + MapReduceAgg mapReduceAgg = MapReduceAggregates.createMapReduceAggCall( + Commons.cluster(), + call, + mapping.getTargetCount(), + inRowType, + true + ); SortAggregateNode<Object[]> aggRdc = new SortAggregateNode<>( ctx, diff --git a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/AbstractPlannerTest.java b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/AbstractPlannerTest.java index 95ac9c5fda..ec2456f5dd 100644 --- a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/AbstractPlannerTest.java +++ b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/AbstractPlannerTest.java @@ -65,6 +65,7 @@ import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexNode; import org.apache.calcite.schema.ColumnStrategy; import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.sql.SqlExplainFormat; import org.apache.calcite.sql.SqlExplainLevel; import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlNode; @@ -483,6 +484,9 @@ public abstract class AbstractPlannerTest extends IgniteAbstractTest { ) throws Exception { IgniteRel plan = physicalPlan(sql, schemas, hintStrategies, params, null, disabledRules); + String planString = RelOptUtil.dumpPlan("", plan, SqlExplainFormat.TEXT, SqlExplainLevel.ALL_ATTRIBUTES); + log.info("statement: {}\n{}", sql, planString); + checkSplitAndSerialization(plan, schemas); try { diff --git a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceAggregatesTest.java b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceAggregatesTest.java index 9a4ded58dc..bd33ecfd99 100644 --- a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceAggregatesTest.java +++ b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceAggregatesTest.java @@ -30,6 +30,7 @@ import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rel.logical.LogicalValues; import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.ImmutableBitSet; @@ -37,6 +38,7 @@ import org.apache.calcite.util.Pair; import org.apache.calcite.util.mapping.Mapping; import org.apache.calcite.util.mapping.MappingType; import org.apache.calcite.util.mapping.Mappings; +import org.apache.ignite.internal.sql.engine.rel.IgniteProject; import org.apache.ignite.internal.sql.engine.rel.IgniteRel; import org.apache.ignite.internal.sql.engine.rel.IgniteValues; import org.apache.ignite.internal.sql.engine.rel.agg.MapReduceAggregates; @@ -117,9 +119,14 @@ public class MapReduceAggregatesTest { return createOutExpr(cluster, input); } + @Override + public IgniteRel makeProject(RelOptCluster cluster, RelNode input, List<RexNode> reduceInputExprs, RelDataType projectRowType) { + return new IgniteProject(cluster, input.getTraitSet(), input, reduceInputExprs, projectRowType); + } + @Override public IgniteRel makeReduceAgg(RelOptCluster cluster, - RelNode map, + RelNode input, ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets, List<AggregateCall> aggregateCalls, @@ -127,7 +134,7 @@ public class MapReduceAggregatesTest { collectedGroupSets.add(Pair.of(groupSet, groupSets)); - return createOutExpr(cluster, map); + return createOutExpr(cluster, input); } private IgniteValues createOutExpr(RelOptCluster cluster, RelNode input) { diff --git a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceHashAggregatePlannerTest.java b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceHashAggregatePlannerTest.java index 7d803c3277..5d1813b1b0 100644 --- a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceHashAggregatePlannerTest.java +++ b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceHashAggregatePlannerTest.java @@ -18,9 +18,6 @@ package org.apache.ignite.internal.sql.engine.planner; import static java.util.function.Predicate.not; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.containsString; -import static org.junit.jupiter.api.Assertions.assertThrows; import java.util.List; import java.util.Objects; @@ -34,7 +31,6 @@ import org.apache.ignite.internal.sql.engine.rel.IgniteExchange; import org.apache.ignite.internal.sql.engine.rel.IgniteLimit; import org.apache.ignite.internal.sql.engine.rel.IgniteMergeJoin; import org.apache.ignite.internal.sql.engine.rel.IgniteProject; -import org.apache.ignite.internal.sql.engine.rel.IgniteRel; import org.apache.ignite.internal.sql.engine.rel.IgniteSort; import org.apache.ignite.internal.sql.engine.rel.agg.IgniteMapHashAggregate; import org.apache.ignite.internal.sql.engine.rel.agg.IgniteReduceHashAggregate; @@ -476,42 +472,42 @@ public class MapReduceHashAggregatePlannerTest extends AbstractAggregatePlannerT assertPlan(TestCase.CASE_22, nonColocated, disableRules); assertPlan(TestCase.CASE_22A, nonColocated, disableRules); - - Predicate<RelNode> colocated = hasChildThat(isInstanceOf(IgniteReduceHashAggregate.class) - .and(in -> hasAggregates(countReduce).test(in.getAggregateCalls())) - .and(input(isInstanceOf(IgniteExchange.class) - .and(input(isInstanceOf(IgniteMapHashAggregate.class) - .and(in -> hasAggregates(countMap).test(in.getAggCallList())) - .and(input(isTableScan("TEST"))) - ) - )) - )); - - assertPlan(TestCase.CASE_22B, colocated, disableRules); - assertPlan(TestCase.CASE_22C, colocated, disableRules); + assertPlan(TestCase.CASE_22B, nonColocated, disableRules); + assertPlan(TestCase.CASE_22C, nonColocated, disableRules); } /** - * Validates that AVG can not be used as two phase mode. - * Should be fixed with TODO https://issues.apache.org/jira/browse/IGNITE-20009 + * Validates that AVG aggregate is split into multiple expressions. */ @Test - public void testAvgAgg() { - RuntimeException e = assertThrows(RuntimeException.class, - () -> assertPlan(TestCase.CASE_23, isInstanceOf(IgniteRel.class), disableRules)); - assertThat(e.getMessage(), containsString("There are not enough rules to produce a node with desired properties")); - - e = assertThrows(RuntimeException.class, - () -> assertPlan(TestCase.CASE_23A, isInstanceOf(IgniteRel.class), disableRules)); - assertThat(e.getMessage(), containsString("There are not enough rules to produce a node with desired properties")); - - e = assertThrows(RuntimeException.class, - () -> assertPlan(TestCase.CASE_23B, isInstanceOf(IgniteRel.class), disableRules)); - assertThat(e.getMessage(), containsString("There are not enough rules to produce a node with desired properties")); - - e = assertThrows(RuntimeException.class, - () -> assertPlan(TestCase.CASE_23C, isInstanceOf(IgniteRel.class), disableRules)); - assertThat(e.getMessage(), containsString("There are not enough rules to produce a node with desired properties")); + public void twoPhaseAvgAgg() throws Exception { + Predicate<AggregateCall> sumMap = (a) -> + Objects.equals(a.getAggregation().getName(), "SUM") && a.getArgList().equals(List.of(1)); + + Predicate<AggregateCall> countMap = (a) -> + Objects.equals(a.getAggregation().getName(), "COUNT") && a.getArgList().equals(List.of(1)); + + Predicate<AggregateCall> sumReduce = (a) -> + Objects.equals(a.getAggregation().getName(), "SUM") && a.getArgList().equals(List.of(1)); + + Predicate<AggregateCall> sum0Reduce = (a) -> + Objects.equals(a.getAggregation().getName(), "$SUM0") && a.getArgList().equals(List.of(2)); + + Predicate<RelNode> nonColocated = hasChildThat(isInstanceOf(IgniteReduceHashAggregate.class) + .and(in -> hasAggregates(sumReduce, sum0Reduce).test(in.getAggregateCalls())) + .and(input(isInstanceOf(IgniteProject.class) + .and(input(isInstanceOf(IgniteExchange.class) + .and(hasDistribution(IgniteDistributions.single())) + .and(input(isInstanceOf(IgniteMapHashAggregate.class) + .and(in -> hasAggregates(sumMap, countMap).test(in.getAggCallList())) + ) + )) + )))); + + assertPlan(TestCase.CASE_23, nonColocated, disableRules); + assertPlan(TestCase.CASE_23A, nonColocated, disableRules); + assertPlan(TestCase.CASE_23B, nonColocated, disableRules); + assertPlan(TestCase.CASE_23C, nonColocated, disableRules); } /** diff --git a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceSortAggregatePlannerTest.java b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceSortAggregatePlannerTest.java index 88543a92ee..5329cccbc0 100644 --- a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceSortAggregatePlannerTest.java +++ b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceSortAggregatePlannerTest.java @@ -19,9 +19,6 @@ package org.apache.ignite.internal.sql.engine.planner; import static java.util.function.Predicate.not; import static org.apache.ignite.internal.sql.engine.trait.IgniteDistributions.single; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.containsString; -import static org.junit.jupiter.api.Assertions.assertThrows; import java.util.List; import java.util.Objects; @@ -36,7 +33,6 @@ import org.apache.ignite.internal.sql.engine.rel.IgniteExchange; import org.apache.ignite.internal.sql.engine.rel.IgniteLimit; import org.apache.ignite.internal.sql.engine.rel.IgniteMergeJoin; import org.apache.ignite.internal.sql.engine.rel.IgniteProject; -import org.apache.ignite.internal.sql.engine.rel.IgniteRel; import org.apache.ignite.internal.sql.engine.rel.IgniteSort; import org.apache.ignite.internal.sql.engine.rel.agg.IgniteMapSortAggregate; import org.apache.ignite.internal.sql.engine.rel.agg.IgniteReduceSortAggregate; @@ -482,25 +478,38 @@ public class MapReduceSortAggregatePlannerTest extends AbstractAggregatePlannerT } /** - * Validates that AVG can not be used as two phase mode. Should be fixed with TODO https://issues.apache.org/jira/browse/IGNITE-20009 + * Validates that AVG aggregate is split into multiple expressions: + * SUM(col) as s, COUNT(col) as c on map phase and then SUM(s)/SUM0(c) on reduce phase. */ @Test - public void testAvgAgg() { - RuntimeException e = assertThrows(RuntimeException.class, - () -> assertPlan(TestCase.CASE_23, isInstanceOf(IgniteRel.class), disableRules)); - assertThat(e.getMessage(), containsString("There are not enough rules to produce a node with desired properties")); - - e = assertThrows(RuntimeException.class, - () -> assertPlan(TestCase.CASE_23A, isInstanceOf(IgniteRel.class), disableRules)); - assertThat(e.getMessage(), containsString("There are not enough rules to produce a node with desired properties")); - - e = assertThrows(RuntimeException.class, - () -> assertPlan(TestCase.CASE_23B, isInstanceOf(IgniteRel.class), disableRules)); - assertThat(e.getMessage(), containsString("There are not enough rules to produce a node with desired properties")); - - e = assertThrows(RuntimeException.class, - () -> assertPlan(TestCase.CASE_23C, isInstanceOf(IgniteRel.class), disableRules)); - assertThat(e.getMessage(), containsString("There are not enough rules to produce a node with desired properties")); + public void twoPhaseAvgAgg() throws Exception { + Predicate<AggregateCall> sumMap = (a) -> + Objects.equals(a.getAggregation().getName(), "SUM") && a.getArgList().equals(List.of(1)); + + Predicate<AggregateCall> countMap = (a) -> + Objects.equals(a.getAggregation().getName(), "COUNT") && a.getArgList().equals(List.of(1)); + + Predicate<AggregateCall> sumReduce = (a) -> + Objects.equals(a.getAggregation().getName(), "SUM") && a.getArgList().equals(List.of(1)); + + Predicate<AggregateCall> sum0Reduce = (a) -> + Objects.equals(a.getAggregation().getName(), "$SUM0") && a.getArgList().equals(List.of(2)); + + Predicate<RelNode> nonColocated = hasChildThat(isInstanceOf(IgniteReduceSortAggregate.class) + .and(in -> hasAggregates(sumReduce, sum0Reduce).test(in.getAggregateCalls())) + .and(input(isInstanceOf(IgniteProject.class) + .and(input(isInstanceOf(IgniteExchange.class) + .and(hasDistribution(single())) + .and(input(isInstanceOf(IgniteMapSortAggregate.class) + .and(in -> hasAggregates(sumMap, countMap).test(in.getAggCallList())) + ) + )) + )))); + + assertPlan(TestCase.CASE_23, nonColocated, disableRules); + assertPlan(TestCase.CASE_23A, nonColocated, disableRules); + assertPlan(TestCase.CASE_23B, nonColocated, disableRules); + assertPlan(TestCase.CASE_23C, nonColocated, disableRules); } /** diff --git a/modules/sql-engine/src/testFixtures/java/org/apache/ignite/internal/sql/engine/util/QueryChecker.java b/modules/sql-engine/src/testFixtures/java/org/apache/ignite/internal/sql/engine/util/QueryChecker.java index 36a46d2009..96a1f6e1ee 100644 --- a/modules/sql-engine/src/testFixtures/java/org/apache/ignite/internal/sql/engine/util/QueryChecker.java +++ b/modules/sql-engine/src/testFixtures/java/org/apache/ignite/internal/sql/engine/util/QueryChecker.java @@ -187,8 +187,8 @@ public interface QueryChecker { static void assertEqualsCollections(Collection<?> exp, Collection<?> act) { if (exp.size() != act.size()) { String errorMsg = new IgniteStringBuilder("Collections sizes are not equal:").nl() - .app("\tExpected: ").app(exp.size()).nl() - .app("\t Actual: ").app(act.size()).toString(); + .app("\tExpected: ").app(exp.size()).app(exp).nl() + .app("\t Actual: ").app(act.size()).app(act).toString(); throw new AssertionError(errorMsg); }