This is an automated email from the ASF dual-hosted git repository.
korlov pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/ignite-3.git
The following commit(s) were added to refs/heads/main by this push:
new 625aad225d IGNITE-20009: Sql. Rework 2-phase aggregates part 2. AVG as
SUM / COUNT. (#3413)
625aad225d is described below
commit 625aad225d2074cbf7e5419936532feb80d12cba
Author: Max Zhuravkov <[email protected]>
AuthorDate: Tue Mar 26 15:56:23 2024 +0200
IGNITE-20009: Sql. Rework 2-phase aggregates part 2. AVG as SUM / COUNT.
(#3413)
---
.../internal/sql/engine/ItAggregatesTest.java | 190 +++++++++-
.../sql/engine/exec/LogicalRelImplementor.java | 5 +-
.../sql/engine/exec/exp/IgniteSqlFunctions.java | 10 +
.../internal/sql/engine/exec/exp/RexImpTable.java | 1 +
.../sql/engine/rel/agg/MapReduceAggregates.java | 381 +++++++++++++++++++--
.../engine/rule/HashAggregateConverterRule.java | 24 +-
.../engine/rule/SortAggregateConverterRule.java | 19 +-
.../sql/engine/sql/fun/IgniteSqlOperatorTable.java | 27 ++
.../internal/sql/engine/util/IgniteMethod.java | 10 +-
.../ignite/internal/sql/engine/util/PlanUtils.java | 12 +-
.../engine/exec/exp/IgniteSqlFunctionsTest.java | 23 ++
.../exec/rel/HashAggregateExecutionTest.java | 8 +-
.../rel/HashAggregateSingleGroupExecutionTest.java | 4 +-
.../exec/rel/SortAggregateExecutionTest.java | 8 +-
.../sql/engine/planner/AbstractPlannerTest.java | 4 +
.../engine/planner/MapReduceAggregatesTest.java | 11 +-
.../planner/MapReduceHashAggregatePlannerTest.java | 66 ++--
.../planner/MapReduceSortAggregatePlannerTest.java | 51 +--
.../internal/sql/engine/util/QueryChecker.java | 4 +-
19 files changed, 733 insertions(+), 125 deletions(-)
diff --git
a/modules/sql-engine/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItAggregatesTest.java
b/modules/sql-engine/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItAggregatesTest.java
index 745f3c8d75..f3e12c117a 100644
---
a/modules/sql-engine/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItAggregatesTest.java
+++
b/modules/sql-engine/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItAggregatesTest.java
@@ -21,10 +21,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
+import java.math.BigDecimal;
+import java.math.MathContext;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
-import java.util.stream.Collectors;
+import java.util.Random;
import java.util.stream.Stream;
import org.apache.ignite.internal.sql.BaseSqlIntegrationTest;
import org.apache.ignite.internal.sql.engine.hint.IgniteHint;
@@ -49,9 +52,6 @@ public class ItAggregatesTest extends BaseSqlIntegrationTest {
private static final List<String> MAP_REDUCE_RULES =
List.of("MapReduceHashAggregateConverterRule",
"MapReduceSortAggregateConverterRule");
- private static final List<String> COLO_RULES =
Arrays.stream(DISABLED_RULES).filter(r -> !MAP_REDUCE_RULES.contains(r))
- .collect(Collectors.toList());
-
private static final int ROWS = 103;
/**
@@ -88,6 +88,25 @@ public class ItAggregatesTest extends BaseSqlIntegrationTest
{
sql("CREATE TABLE test_str_int_real_dec "
+ "(id INTEGER PRIMARY KEY, str_col VARCHAR, int_col INTEGER,
real_col REAL, dec_col DECIMAL)");
+
+ sql("CREATE TABLE IF NOT EXISTS numbers ("
+ + "id INTEGER PRIMARY KEY, "
+ + "tinyint_col TINYINT, "
+ + "smallint_col SMALLINT, "
+ + "int_col INTEGER, "
+ + "bigint_col BIGINT, "
+ + "float_col REAL, "
+ + "double_col DOUBLE, "
+ + "dec2_col DECIMAL(2), "
+ + "dec4_2_col DECIMAL(4,2), "
+ + "dec10_2_col DECIMAL(10,2) "
+ + ")");
+
+ sql("CREATE TABLE IF NOT EXISTS not_null_numbers ("
+ + "id INTEGER PRIMARY KEY, "
+ + "int_col INTEGER NOT NULL, "
+ + "dec4_2_col DECIMAL(4,2) NOT NULL"
+ + ")");
}
@ParameterizedTest
@@ -530,27 +549,166 @@ public class ItAggregatesTest extends
BaseSqlIntegrationTest {
@ParameterizedTest
@MethodSource("provideRules")
- public void testAvgOnEmptyGroup(String[] rules) {
- sql("DELETE FROM test_str_int_real_dec");
+ public void testAvg(String[] rules) {
+ sql("DELETE FROM numbers");
+ sql("INSERT INTO numbers VALUES (1, 1, 1, 1, 1, 1, 1, 1, 1, 1), (2, 2,
2, 2, 2, 2, 2, 2, 2, 2)");
+
+ assertQuery("SELECT "
+ + "AVG(tinyint_col), AVG(smallint_col), AVG(int_col),
AVG(bigint_col), "
+ + "AVG(float_col), AVG(double_col), AVG(dec2_col),
AVG(dec4_2_col) "
+ + "FROM numbers")
+ .disableRules(rules)
+ .returns((byte) 1, (short) 1, 1, 1L, 1.5f, 1.5d, new
BigDecimal("1.5"), new BigDecimal("1.50"))
+ .check();
+
+ sql("DELETE FROM numbers");
+ sql("INSERT INTO numbers (id, dec4_2_col) VALUES (1, 1), (2, 2)");
+
+ assertQuery("SELECT AVG(dec4_2_col) FROM numbers")
+ .disableRules(rules)
+ .returns(new BigDecimal("1.50"))
+ .check();
+
+ sql("DELETE FROM numbers");
+ sql("INSERT INTO numbers (id, dec4_2_col) VALUES (1, 1), (2, 2.3333)");
+
+ assertQuery("SELECT AVG(dec4_2_col) FROM numbers")
+ .disableRules(rules)
+ .returns(new BigDecimal("1.665"))
+ .check();
+
+ sql("DELETE FROM numbers");
+ sql("INSERT INTO numbers (id, int_col, dec4_2_col) VALUES (1, null,
null)");
+
+ assertQuery("SELECT AVG(int_col), AVG(dec4_2_col) FROM numbers")
+ .disableRules(rules)
+ .returns(null, null)
+ .check();
+
+ sql("DELETE FROM numbers");
+ sql("INSERT INTO numbers (id, int_col, dec4_2_col) VALUES (1, 1, 1),
(2, null, null)");
+
+ assertQuery("SELECT AVG(int_col), AVG(dec4_2_col) FROM numbers")
+ .disableRules(rules)
+ .returns(1, new BigDecimal("1.00"))
+ .check();
+ }
+
+ @Test
+ public void testAvgRandom() {
+ long seed = System.nanoTime();
+ Random random = new Random(seed);
+
+ sql("DELETE FROM numbers");
+
+ List<BigDecimal> numbers = new ArrayList<>();
+ log.info("Seed: {}", seed);
+
+ for (int i = 1; i < 20; i++) {
+ int val = random.nextInt(100) + 1;
+ BigDecimal num = BigDecimal.valueOf(val);
+ numbers.add(num);
+
+ String query = "INSERT INTO numbers (id, int_col, dec10_2_col)
VALUES(?, ?, ?)";
+ sql(query, i, num.intValue(), num);
+ }
+
+ BigDecimal avg = numbers.stream()
+ .reduce(new BigDecimal("0.00"), BigDecimal::add)
+ .divide(BigDecimal.valueOf(numbers.size()),
MathContext.DECIMAL64);
+
+ for (String[] rules : makePermutations(DISABLED_RULES)) {
+ assertQuery("SELECT AVG(int_col), AVG(dec10_2_col) FROM numbers")
+ .disableRules(rules)
+ .returns(avg.intValue(), avg)
+ .check();
+ }
+ }
+
+ @ParameterizedTest
+ @MethodSource("provideRules")
+ public void testAvgNullNotNull(String[] rules) {
+ sql("DELETE FROM not_null_numbers");
+ sql("INSERT INTO not_null_numbers (id, int_col, dec4_2_col) VALUES (1,
1, 1), (2, 2, 2)");
+
+ assertQuery("SELECT AVG(int_col), AVG(dec4_2_col) FROM
not_null_numbers")
+ .disableRules(rules)
+ .returns(1, new BigDecimal("1.50"))
+ .check();
+
+ // Return type of an AVG aggregate can never be null.
+ assertQuery("SELECT AVG(int_col) FROM not_null_numbers GROUP BY
int_col")
+ .disableRules(rules)
+ .returns(1)
+ .returns(2)
+ .check();
+
+ assertQuery("SELECT AVG(dec4_2_col) FROM not_null_numbers GROUP BY
dec4_2_col")
+ .disableRules(rules)
+ .returns(new BigDecimal("1.00"))
+ .returns(new BigDecimal("2.00"))
+ .check();
- // TODO https://issues.apache.org/jira/browse/IGNITE-20009
- // Remove after is fixed.
- Assumptions.assumeFalse(Arrays.stream(rules)
- .filter(COLO_RULES::contains).count() == COLO_RULES.size(),
"AVG is disabled for MAP/REDUCE");
+ sql("DELETE FROM numbers");
+ sql("INSERT INTO numbers (id, int_col, dec4_2_col) VALUES (1, 1, 1),
(2, 2, 2)");
- assertQuery("SELECT AVG(int_col) FROM test_str_int_real_dec")
+ assertQuery("SELECT AVG(int_col) FROM numbers GROUP BY int_col")
.disableRules(rules)
- .returns(new Object[]{null})
+ .returns(1)
+ .returns(2)
.check();
- assertQuery("SELECT AVG(real_col) FROM test_str_int_real_dec")
+ assertQuery("SELECT AVG(dec4_2_col) FROM numbers GROUP BY dec4_2_col")
.disableRules(rules)
- .returns(new Object[]{null})
+ .returns(new BigDecimal("1.00"))
+ .returns(new BigDecimal("2.00"))
+ .check();
+ }
+
+ @ParameterizedTest
+ @MethodSource("provideRules")
+ public void testAvgOnEmptyGroup(String[] rules) {
+ sql("DELETE FROM numbers");
+
+ assertQuery("SELECT "
+ + "AVG(tinyint_col), AVG(smallint_col), AVG(int_col),
AVG(bigint_col), "
+ + "AVG(float_col), AVG(double_col), AVG(dec2_col),
AVG(dec4_2_col) "
+ + "FROM numbers")
+ .disableRules(rules)
+ .returns(null, null, null, null, null, null, null, null)
+ .check();
+ }
+
+ @ParameterizedTest
+ @MethodSource("provideRules")
+ public void testAvgFromLiterals(String[] rules) {
+
+ assertQuery("SELECT "
+ + "AVG(tinyint_col), AVG(smallint_col), AVG(int_col),
AVG(bigint_col), "
+ + "AVG(float_col), AVG(double_col), AVG(dec2_col),
AVG(dec4_2_col) "
+ + "FROM (VALUES "
+ + "(1::TINYINT, 1::SMALLINT, 1::INTEGER, 1::BIGINT, 1::REAL,
1::DOUBLE, 1::DECIMAL(2), 1.00::DECIMAL(4,2)), "
+ + "(2::TINYINT, 2::SMALLINT, 2::INTEGER, 2::BIGINT, 2::REAL,
2::DOUBLE, 2::DECIMAL(2), 2.00::DECIMAL(4,2)) "
+ + ") "
+ + "t(tinyint_col, smallint_col, int_col, bigint_col,
float_col, double_col, dec2_col, dec4_2_col)")
+ .disableRules(rules)
+ .returns((byte) 1, (short) 1, 1, 1L, 1.5f, 1.5d, new
BigDecimal("1.5"), new BigDecimal("1.50"))
.check();
- assertQuery("SELECT AVG(dec_col) FROM test_str_int_real_dec")
+ assertQuery("SELECT "
+ + "AVG(1::TINYINT), AVG(2::SMALLINT), AVG(3::INTEGER),
AVG(4::BIGINT), "
+ + "AVG(5::REAL), AVG(6::DOUBLE), AVG(7::DECIMAL(2)),
AVG(8.00::DECIMAL(4,2))")
.disableRules(rules)
- .returns(new Object[]{null})
+ .returns((byte) 1, (short) 2, 3, 4L, 5.0f, 6.0d, new
BigDecimal("7"), new BigDecimal("8.00"))
+ .check();
+
+ assertQuery("SELECT AVG(dec2_col), AVG(dec4_2_col) FROM\n"
+ + "(SELECT \n"
+ + " 1::DECIMAL(2) as dec2_col, 2.00::DECIMAL(4, 2) as
dec4_2_col\n"
+ + " UNION\n"
+ + " SELECT 2::DECIMAL(2) as dec2_col, 3.00::DECIMAL(4,2) as
dec4_2_col\n"
+ + ") as t")
+ .returns(new BigDecimal("1.5"), new BigDecimal("2.50"))
.check();
}
diff --git
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/LogicalRelImplementor.java
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/LogicalRelImplementor.java
index 6b44790bd7..5ebeedd8a5 100644
---
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/LogicalRelImplementor.java
+++
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/LogicalRelImplementor.java
@@ -744,7 +744,10 @@ public class LogicalRelImplementor<RowT> implements
IgniteRelVisitor<Node<RowT>>
RelDataType rowType = rel.getRowType();
Supplier<List<AccumulatorWrapper<RowT>>> accFactory =
expressionFactory.accumulatorsFactory(
- type, rel.getAggregateCalls(), null);
+ type,
+ rel.getAggregateCalls(),
+ rel.getInput().getRowType()
+ );
RowSchema rowSchema =
rowSchemaFromRelTypes(RelOptUtil.getFieldTypeList(rowType));
RowFactory<RowT> rowFactory = ctx.rowHandler().factory(rowSchema);
diff --git
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctions.java
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctions.java
index 5f1f31d5d0..eb4639c45e 100644
---
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctions.java
+++
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctions.java
@@ -25,6 +25,7 @@ import static
org.apache.ignite.lang.ErrorGroups.Sql.RUNTIME_ERR;
import java.math.BigDecimal;
import java.math.BigInteger;
+import java.math.MathContext;
import java.math.RoundingMode;
import java.time.LocalDateTime;
import java.time.LocalTime;
@@ -51,6 +52,7 @@ import org.apache.calcite.schema.Statistic;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.ignite.internal.sql.engine.sql.fun.IgniteSqlOperatorTable;
import org.apache.ignite.internal.sql.engine.type.IgniteTypeSystem;
import org.apache.ignite.internal.sql.engine.util.Commons;
import org.apache.ignite.internal.sql.engine.util.TypeUtils;
@@ -473,6 +475,14 @@ public class IgniteSqlFunctions {
}
}
+ /**
+ * Division function for the REDUCE phase of the AVG aggregate. Precision and
scale are only used by type inference
+ * (see {@link IgniteSqlOperatorTable#DECIMAL_DIVIDE}); their values are
ignored at runtime.
+ */
+ public static BigDecimal decimalDivide(BigDecimal sum, BigDecimal cnt, int
p, int s) {
+ return sum.divide(cnt, MathContext.DECIMAL64);
+ }
+
private static BigDecimal processValueWithIntegralPart(Number value, int
precision, int scale) {
BigDecimal dec = convertToBigDecimal(value);
diff --git
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/RexImpTable.java
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/RexImpTable.java
index 48db266aad..334446bbea 100644
---
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/RexImpTable.java
+++
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/RexImpTable.java
@@ -1017,6 +1017,7 @@ public class RexImpTable {
defineMethod(ROUND, IgniteMethod.ROUND.method(), NullPolicy.STRICT);
defineMethod(TRUNCATE, IgniteMethod.TRUNCATE.method(),
NullPolicy.STRICT);
defineMethod(IgniteSqlOperatorTable.SUBSTRING,
IgniteMethod.SUBSTRING.method(), NullPolicy.STRICT);
+ defineMethod(IgniteSqlOperatorTable.DECIMAL_DIVIDE,
IgniteMethod.DECIMAL_DIVIDE.method(), NullPolicy.ARG0);
map.put(TYPEOF, systemFunctionImplementor);
map.put(SYSTEM_RANGE, systemFunctionImplementor);
diff --git
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rel/agg/MapReduceAggregates.java
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rel/agg/MapReduceAggregates.java
index 95b060088d..ec8c2819e4 100644
---
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rel/agg/MapReduceAggregates.java
+++
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rel/agg/MapReduceAggregates.java
@@ -20,6 +20,7 @@ package org.apache.ignite.internal.sql.engine.rel.agg;
import static org.apache.ignite.internal.lang.IgniteStringFormatter.format;
import com.google.common.collect.ImmutableList;
+import java.math.BigDecimal;
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.List;
@@ -34,19 +35,24 @@ import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeFactory.Builder;
import org.apache.calcite.rel.type.RelDataTypeField;
+import org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.ignite.internal.sql.engine.rel.IgniteProject;
import org.apache.ignite.internal.sql.engine.rel.IgniteRel;
+import org.apache.ignite.internal.sql.engine.sql.fun.IgniteSqlOperatorTable;
import org.apache.ignite.internal.sql.engine.type.IgniteTypeFactory;
import org.apache.ignite.internal.sql.engine.util.Commons;
+import org.jetbrains.annotations.TestOnly;
/**
* Map/reduce aggregate utility methods.
@@ -62,6 +68,7 @@ public class MapReduceAggregates {
"EVERY",
"SOME",
"ANY",
+ "AVG",
"SINGLE_VALUE",
"ANY_VALUE"
);
@@ -92,6 +99,10 @@ public class MapReduceAggregates {
/**
* Creates a physical operator that implements the given logical aggregate
as MAP/REDUCE.
*
+ * <p>Final expression consists of an aggregate for map phase (MapNode)
and another aggregate for reduce phase (ReduceNode).
+ * Depending on an accumulator function there can be an intermediate
projection between MAP and REDUCE, and a projection after REDUCE
+ * as well.
+ *
* @param agg Logical aggregate expression.
* @param builder Builder to create implementations of MAP and REDUCE
phases.
* @param fieldMappingOnReduce Mapping to be applied to group sets on
REDUCE phase.
@@ -125,14 +136,25 @@ public class MapReduceAggregates {
// groupSet includes all columns from GROUP BY/GROUPING SETS clauses.
int argumentOffset = agg.getGroupSet().cardinality();
+ // MAP PHASE AGGREGATE
+
List<AggregateCall> mapAggCalls = new
ArrayList<>(agg.getAggCallList().size());
for (AggregateCall call : agg.getAggCallList()) {
- MapReduceAgg mapReduceAgg = createMapReduceAggCall(call,
argumentOffset);
- argumentOffset += 1;
+ // See ReturnTypes::AVG_AGG_FUNCTION: the result type of an aggregate
with no grouping or with filtering can be nullable.
+ boolean canBeNull = agg.getGroupCount() == 0 || call.hasFilter();
+
+ MapReduceAgg mapReduceAgg = createMapReduceAggCall(
+ Commons.cluster(),
+ call,
+ argumentOffset,
+ agg.getInput().getRowType(),
+ canBeNull
+ );
+ argumentOffset += mapReduceAgg.reduceCalls.size();
mapReduceAggs.add(mapReduceAgg);
- mapAggCalls.add(mapReduceAgg.mapCall);
+ mapAggCalls.addAll(mapReduceAgg.mapCalls);
}
// MAP phase should have no less than the number of arguments as
original aggregate.
@@ -148,7 +170,10 @@ public class MapReduceAggregates {
mapAggCalls
);
- List<RelDataTypeField> outputRowFields =
agg.getRowType().getFieldList();
+ //
+ // REDUCE INPUT PROJECTION
+ //
+
RelDataTypeFactory.Builder reduceType = new
Builder(Commons.typeFactory());
int groupByColumns = agg.getGroupSet().cardinality();
@@ -158,26 +183,84 @@ public class MapReduceAggregates {
// It consists of columns from agg.groupSet and aggregate expressions.
for (int i = 0; i < groupByColumns; i++) {
+ List<RelDataTypeField> outputRowFields =
agg.getRowType().getFieldList();
RelDataType type = outputRowFields.get(i).getType();
reduceType.add("f" + reduceType.getFieldCount(), type);
}
+ RexBuilder rexBuilder = agg.getCluster().getRexBuilder();
+ IgniteTypeFactory typeFactory = (IgniteTypeFactory)
agg.getCluster().getTypeFactory();
+
+ List<RexNode> reduceInputExprs = new ArrayList<>();
+
+ for (int i = 0; i < map.getRowType().getFieldList().size(); i++) {
+ RelDataType type =
map.getRowType().getFieldList().get(i).getType();
+ RexInputRef ref = new RexInputRef(i, type);
+ reduceInputExprs.add(ref);
+ }
+
+ // Build a list of projections for reduce operator,
+ // if all projections are identity, it is not necessary
+ // to create a projection between MAP and REDUCE operators.
+
+ boolean additionalProjectionsForReduce = false;
+
+ for (int i = 0, argOffset = 0; i < mapReduceAggs.size(); i++) {
+ MapReduceAgg mapReduceAgg = mapReduceAggs.get(i);
+ int argIdx = groupByColumns + argOffset;
+
+ for (int j = 0; j < mapReduceAgg.reduceCalls.size(); j++) {
+ RexNode projExpr =
mapReduceAgg.makeReduceInputExpr.makeExpr(rexBuilder, map, List.of(argIdx),
typeFactory);
+ reduceInputExprs.set(argIdx, projExpr);
+
+ if (mapReduceAgg.makeReduceInputExpr != USE_INPUT_FIELD) {
+ additionalProjectionsForReduce = true;
+ }
+
+ argIdx += 1;
+ }
+
+ argOffset += mapReduceAgg.reduceCalls.size();
+ }
+
+ RelNode reduceInputNode;
+ if (additionalProjectionsForReduce) {
+ RelDataTypeFactory.Builder projectRow = new
Builder(agg.getCluster().getTypeFactory());
+
+ for (int i = 0; i < reduceInputExprs.size(); i++) {
+ RexNode rexNode = reduceInputExprs.get(i);
+ projectRow.add(String.valueOf(i), rexNode.getType());
+ }
+
+ RelDataType projectRowType = projectRow.build();
+
+ reduceInputNode = builder.makeProject(agg.getCluster(), map,
reduceInputExprs, projectRowType);
+ } else {
+ reduceInputNode = map;
+ }
+
+ //
+ // REDUCE PHASE AGGREGATE
+ //
// Build a list of aggregate calls for REDUCE phase.
- // Build a list of projection that accept reduce phase and
combine/collect/cast results.
+ // Build a list of projections (arg-list, expr) that accept reduce
phase and combine/collect/cast results.
List<AggregateCall> reduceAggCalls = new ArrayList<>();
- List<Map.Entry<List<Integer>, MakeReduceExpr>> projection = new
ArrayList<>();
+ List<Map.Entry<List<Integer>, MakeReduceExpr>> projection = new
ArrayList<>(mapReduceAggs.size());
for (MapReduceAgg mapReduceAgg : mapReduceAggs) {
// Update row type returned by REDUCE node.
- AggregateCall reduceCall = mapReduceAgg.reduceCall;
- reduceType.add("f" + reduceType.getFieldCount(),
reduceCall.getType());
- reduceAggCalls.add(reduceCall);
+ int i = 0;
+ for (AggregateCall reduceCall : mapReduceAgg.reduceCalls) {
+ reduceType.add("f" + i + "_" + reduceType.getFieldCount(),
reduceCall.getType());
+ reduceAggCalls.add(reduceCall);
+ i += 1;
+ }
// Update projection list
- List<Integer> argList = mapReduceAgg.argList;
- MakeReduceExpr projectionExpr = mapReduceAgg.makeReduceExpr;
- projection.add(new SimpleEntry<>(argList, projectionExpr));
+ List<Integer> reduceArgList = mapReduceAgg.argList;
+ MakeReduceExpr projectionExpr = mapReduceAgg.makeReduceOutputExpr;
+ projection.add(new SimpleEntry<>(reduceArgList, projectionExpr));
if (projectionExpr != USE_INPUT_FIELD) {
sameAggsForBothPhases = false;
@@ -192,7 +275,7 @@ public class MapReduceAggregates {
}
// if the number of aggregates on MAP phase is larger then the number
of aggregates on REDUCE phase,
- // then some of MAP aggregates are not used by REDUCE phase and this
is a bug.
+ // assume that some of MAP aggregates are not used by REDUCE phase and
this is a bug.
//
// NOTE: In general case REDUCE phase can use more aggregates than MAP
phase,
// but at the moment there is no support for such aggregates.
@@ -207,20 +290,23 @@ public class MapReduceAggregates {
IgniteRel reduce = builder.makeReduceAgg(
agg.getCluster(),
- map,
+ reduceInputNode,
groupSetOnReduce,
groupSetsOnReduce,
reduceAggCalls,
reduceTypeToUse
);
+ //
+ // FINAL PROJECTION
+ //
// if aggregate MAP phase uses the same aggregates as REDUCE phase,
// there is no need to add a projection because no additional actions
are required to compute final results.
if (sameAggsForBothPhases) {
return reduce;
}
- List<RexNode> projectionList = new ArrayList<>(projection.size());
+ List<RexNode> projectionList = new ArrayList<>(projection.size() +
groupByColumns);
// Projection list returned by AggregateNode consists of columns from
GROUP BY clause
// and expressions that represent aggregate calls.
@@ -228,17 +314,24 @@ public class MapReduceAggregates {
int i = 0;
for (; i < groupByColumns; i++) {
+ List<RelDataTypeField> outputRowFields =
agg.getRowType().getFieldList();
RelDataType type = outputRowFields.get(i).getType();
RexInputRef ref = new RexInputRef(i, type);
projectionList.add(ref);
}
- RexBuilder rexBuilder = agg.getCluster().getRexBuilder();
- IgniteTypeFactory typeFactory = (IgniteTypeFactory)
agg.getCluster().getTypeFactory();
-
for (Map.Entry<List<Integer>, MakeReduceExpr> expr : projection) {
RexNode resultExpr = expr.getValue().makeExpr(rexBuilder, reduce,
expr.getKey(), typeFactory);
projectionList.add(resultExpr);
+ }
+
+ assert projectionList.size() == agg.getRowType().getFieldList().size()
:
+ format("Projection size does not match. Expected: {} but got
{}",
+ agg.getRowType().getFieldList().size(),
projectionList.size());
+
+ for (i = 0; i < projectionList.size(); i++) {
+ RexNode resultExpr = projectionList.get(i);
+ List<RelDataTypeField> outputRowFields =
agg.getRowType().getFieldList();
// Put assertion here so we can see an expression that caused a
type mismatch,
// since Project::isValid only shows types.
@@ -246,7 +339,6 @@ public class MapReduceAggregates {
format("Type at position#{} does not match. Expected: {}
but got {}.\nREDUCE aggregates: {}\nRow: {}.\nExpr: {}",
i, resultExpr.getType(),
outputRowFields.get(i).getType(), reduceAggCalls, outputRowFields, resultExpr);
- i++;
}
return new IgniteProject(agg.getCluster(), reduce.getTraitSet(),
reduce, projectionList, agg.getRowType());
@@ -255,15 +347,24 @@ public class MapReduceAggregates {
/**
* Creates a MAP/REDUCE details for this call.
*/
- public static MapReduceAgg createMapReduceAggCall(AggregateCall call, int
reduceArgumentOffset) {
+ public static MapReduceAgg createMapReduceAggCall(
+ RelOptCluster cluster,
+ AggregateCall call,
+ int reduceArgumentOffset,
+ RelDataType input,
+ boolean canBeNull
+ ) {
String aggName = call.getAggregation().getName();
assert AGG_SUPPORTING_MAP_REDUCE.contains(aggName) : "Aggregate does
not support MAP/REDUCE " + call;
- if ("COUNT".equals(aggName)) {
- return createCountAgg(call, reduceArgumentOffset);
- } else {
- return createSimpleAgg(call, reduceArgumentOffset);
+ switch (aggName) {
+ case "COUNT":
+ return createCountAgg(call, reduceArgumentOffset);
+ case "AVG":
+ return createAvgAgg(cluster, call, reduceArgumentOffset,
input, canBeNull);
+ default:
+ return createSimpleAgg(call, reduceArgumentOffset);
}
}
@@ -277,8 +378,13 @@ public class MapReduceAggregates {
IgniteRel makeMapAgg(RelOptCluster cluster, RelNode input,
ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets,
List<AggregateCall> aggregateCalls);
+ /**
+ * Creates an intermediate projection operator that converts results
of the MAP phase into inputs for the REDUCE phase.
+ */
+ IgniteRel makeProject(RelOptCluster cluster, RelNode input,
List<RexNode> reduceInputExprs, RelDataType projectRowType);
+
/** Creates a rel node that represents a REDUCE phase.*/
- IgniteRel makeReduceAgg(RelOptCluster cluster, RelNode map,
+ IgniteRel makeReduceAgg(RelOptCluster cluster, RelNode input,
ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets,
List<AggregateCall> aggregateCalls, RelDataType outputType);
}
@@ -286,24 +392,48 @@ public class MapReduceAggregates {
/** Contains information on how to build MAP/REDUCE version of an
aggregate. */
public static class MapReduceAgg {
+ /** Argument list on reduce phase. */
final List<Integer> argList;
- final AggregateCall mapCall;
+ /** MAP phase aggregates that the initial aggregation function was
transformed into. */
+ final List<AggregateCall> mapCalls;
+
+ /** REDUCE phase aggregates that the initial aggregation function was
transformed into. */
+ final List<AggregateCall> reduceCalls;
- final AggregateCall reduceCall;
+ /** Produces expressions to consume results of a MAP phase
aggregation. */
+ final MakeReduceExpr makeReduceInputExpr;
- final MakeReduceExpr makeReduceExpr;
+ /** Produces expressions to consume results of a REDUCE phase
aggregation to comprise final result. */
+ final MakeReduceExpr makeReduceOutputExpr;
- MapReduceAgg(List<Integer> argList, AggregateCall mapCall,
AggregateCall reduceCall, MakeReduceExpr makeReduceExpr) {
+ MapReduceAgg(
+ List<Integer> argList,
+ AggregateCall mapCalls,
+ AggregateCall reduceCalls,
+ MakeReduceExpr makeReduceOutputExpr
+ ) {
+ this(argList, List.of(mapCalls), USE_INPUT_FIELD,
List.of(reduceCalls), makeReduceOutputExpr);
+ }
+
+ MapReduceAgg(
+ List<Integer> argList,
+ List<AggregateCall> mapCalls,
+ MakeReduceExpr makeReduceInputExpr,
+ List<AggregateCall> reduceCalls,
+ MakeReduceExpr makeReduceOutputExpr
+ ) {
this.argList = argList;
- this.mapCall = mapCall;
- this.reduceCall = reduceCall;
- this.makeReduceExpr = makeReduceExpr;
+ this.mapCalls = mapCalls;
+ this.reduceCalls = reduceCalls;
+ this.makeReduceInputExpr = makeReduceInputExpr;
+ this.makeReduceOutputExpr = makeReduceOutputExpr;
}
/** A call for REDUCE phase. */
+ @TestOnly
public AggregateCall getReduceCall() {
- return reduceCall;
+ return reduceCalls.get(0);
}
}
@@ -322,13 +452,13 @@ public class MapReduceAggregates {
null,
call.collation,
call.type,
- null);
+ "COUNT_" + reduceArgumentOffset + "_MAP_SUM");
// COUNT(x) aggregate have type BIGINT, but the type of SUM(COUNT(x))
is DECIMAL,
// so we should convert it to back to BIGINT.
MakeReduceExpr exprBuilder = (rexBuilder, input, args, typeFactory) ->
{
RexInputRef ref = rexBuilder.makeInputRef(input, args.get(0));
- return
rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.BIGINT), ref);
+ return
rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.BIGINT), ref, true,
false);
};
return new MapReduceAgg(argList, call, sum0, exprBuilder);
@@ -356,10 +486,189 @@ public class MapReduceAggregates {
return new MapReduceAgg(argList, call, reduceCall, USE_INPUT_FIELD);
}
+ /**
+ * Produces intermediate expressions that modify results of MAP/REDUCE
aggregate.
+ * For example: after splitting a function into a MAP aggregate and REDUCE
aggregate it is necessary to add casts to
+ * output of a REDUCE phase aggregate.
+ *
+ * <p>In order to avoid creating unnecessary projections, use {@link
MapReduceAggregates#USE_INPUT_FIELD}.
+ */
@FunctionalInterface
private interface MakeReduceExpr {
- /** Creates an expression that produces result of REDUCE phase of an
aggregate. */
+ /**
+ * Creates an expression that applies a performs computation (e.g.
applies some function, adds a cast)
+ * on {@code args} fields of input relation.
+ *
+ * @param rexBuilder Expression builder.
+ * @param input Input relation.
+ * @param args Arguments.
+ * @param typeFactory Type factory.
+ *
+ * @return Expression.
+ */
RexNode makeExpr(RexBuilder rexBuilder, RelNode input, List<Integer>
args, IgniteTypeFactory typeFactory);
}
+
+ private static MapReduceAgg createAvgAgg(
+ RelOptCluster cluster,
+ AggregateCall call,
+ int reduceArgumentOffset,
+ RelDataType inputType,
+ boolean canBeNull
+ ) {
+ RelDataTypeFactory tf = cluster.getTypeFactory();
+ RelDataTypeSystem typeSystem = tf.getTypeSystem();
+
+ RelDataType fieldType =
inputType.getFieldList().get(call.getArgList().get(0)).getType();
+
+ // In case of AVG(NULL) return a simple version of an aggregate,
because result is always NULL.
+ if (fieldType.getSqlTypeName() == SqlTypeName.NULL) {
+ return createSimpleAgg(call, reduceArgumentOffset);
+ }
+
+ // AVG(x) : SUM(x)/COUNT0(x)
+ // MAP : SUM(x) / COUNT(x)
+
+ // SUM(x) as s
+ RelDataType mapSumType = typeSystem.deriveSumType(tf, fieldType);
+ if (canBeNull) {
+ mapSumType = tf.createTypeWithNullability(mapSumType, true);
+ }
+
+ AggregateCall mapSum0 = AggregateCall.create(
+ SqlStdOperatorTable.SUM,
+ call.isDistinct(),
+ call.isApproximate(),
+ call.ignoreNulls(),
+ ImmutableList.of(),
+ call.getArgList(),
+ call.filterArg,
+ null,
+ call.collation,
+ mapSumType,
+ "AVG_SUM" + reduceArgumentOffset);
+
+ // COUNT(x) as c
+ RelDataType mapCountType = tf.createSqlType(SqlTypeName.BIGINT);
+
+ AggregateCall mapCount0 = AggregateCall.create(
+ SqlStdOperatorTable.COUNT,
+ call.isDistinct(),
+ call.isApproximate(),
+ call.ignoreNulls(),
+ ImmutableList.of(),
+ call.getArgList(),
+ call.filterArg,
+ null,
+ call.collation,
+ mapCountType,
+ "AVG_COUNT" + reduceArgumentOffset);
+
+ // REDUCE : SUM(s) as reduce_sum, SUM0(c) as reduce_count
+ List<Integer> reduceSumArgs = List.of(reduceArgumentOffset);
+
+ // SUM0(s)
+ RelDataType reduceSumType = typeSystem.deriveSumType(tf, mapSumType);
+ if (canBeNull) {
+ reduceSumType = tf.createTypeWithNullability(reduceSumType, true);
+ }
+
+ AggregateCall reduceSum0 = AggregateCall.create(
+ SqlStdOperatorTable.SUM,
+ call.isDistinct(),
+ call.isApproximate(),
+ call.ignoreNulls(),
+ ImmutableList.of(),
+ reduceSumArgs,
+ // there is no filtering on REDUCE phase
+ -1,
+ null,
+ call.collation,
+ reduceSumType,
+ "AVG_SUM" + reduceArgumentOffset);
+
+
+ // SUM0(c)
+ RelDataType reduceSumCountType = typeSystem.deriveSumType(tf,
mapCount0.type);
+ List<Integer> reduceSumCountArgs = List.of(reduceArgumentOffset + 1);
+
+ AggregateCall reduceSumCount = AggregateCall.create(
+ SqlStdOperatorTable.SUM0,
+ call.isDistinct(),
+ call.isApproximate(),
+ call.ignoreNulls(),
+ ImmutableList.of(),
+ reduceSumCountArgs,
+ // there is no filtering on REDUCE phase
+ -1,
+ null,
+ call.collation,
+ reduceSumCountType,
+ "AVG_SUM0" + reduceArgumentOffset);
+
+ RelDataType finalReduceSumType = reduceSumType;
+
+ MakeReduceExpr reduceInputExpr = (rexBuilder, input, args,
typeFactory) -> {
+ RexInputRef argExpr = rexBuilder.makeInputRef(input, args.get(0));
+
+ if (args.get(0) == reduceArgumentOffset) {
+ // Accumulator functions handle NULL, so it is safe to ignore
it.
+ if (!SqlTypeUtil.equalSansNullability(finalReduceSumType,
argExpr.getType())) {
+ return rexBuilder.makeCast(finalReduceSumType, argExpr,
true, false);
+ } else {
+ return argExpr;
+ }
+ } else {
+ return rexBuilder.makeCast(reduceSumCount.type, argExpr, true,
false);
+ }
+
+ };
+
+ // PROJECT: reduce_sum/reduce_count
+
+ MakeReduceExpr reduceOutputExpr = (rexBuilder, input, args,
typeFactory) -> {
+ RexNode numeratorRef = rexBuilder.makeInputRef(input, args.get(0));
+ RexInputRef denominatorRef = rexBuilder.makeInputRef(input,
args.get(1));
+
+ RelDataType avgType =
typeFactory.createTypeWithNullability(mapSum0.type,
numeratorRef.getType().isNullable());
+ numeratorRef = rexBuilder.ensureType(avgType, numeratorRef, true);
+
+ RexNode sumDivCnt;
+ if (call.getType().getSqlTypeName() == SqlTypeName.DECIMAL) {
+ // Return correct decimal type with correct scale and
precision.
+ int precision = call.getType().getPrecision();
+ int scale = call.getType().getScale();
+
+ RexLiteral p =
rexBuilder.makeExactLiteral(BigDecimal.valueOf(precision),
tf.createSqlType(SqlTypeName.INTEGER));
+ RexLiteral s =
rexBuilder.makeExactLiteral(BigDecimal.valueOf(scale),
tf.createSqlType(SqlTypeName.INTEGER));
+
+ sumDivCnt =
rexBuilder.makeCall(IgniteSqlOperatorTable.DECIMAL_DIVIDE, numeratorRef,
denominatorRef, p, s);
+ } else {
+ RexNode divideRef =
rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numeratorRef, denominatorRef);
+ sumDivCnt = rexBuilder.makeCast(call.getType(), divideRef,
true, false);
+ }
+
+ if (canBeNull) {
+ // CASE cnt == 0 THEN null
+ // OTHERWISE sum / cnt
+ RexLiteral zero = rexBuilder.makeExactLiteral(BigDecimal.ZERO,
denominatorRef.getType());
+ RexNode eqZero =
rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, numeratorRef, zero);
+ RexLiteral nullRes =
rexBuilder.makeNullLiteral(call.getType());
+
+ return rexBuilder.makeCall(SqlStdOperatorTable.CASE, eqZero,
nullRes, sumDivCnt);
+ } else {
+ return sumDivCnt;
+ }
+ };
+
+ List<Integer> argList = List.of(reduceArgumentOffset,
reduceArgumentOffset + 1);
+ return new MapReduceAgg(
+ argList,
+ List.of(mapSum0, mapCount0),
+ reduceInputExpr,
+ List.of(reduceSum0, reduceSumCount),
+ reduceOutputExpr
+ );
+ }
}
diff --git
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rule/HashAggregateConverterRule.java
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rule/HashAggregateConverterRule.java
index b14010d030..15d72f5f7f 100644
---
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rule/HashAggregateConverterRule.java
+++
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rule/HashAggregateConverterRule.java
@@ -31,9 +31,11 @@ import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexNode;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.ignite.internal.sql.engine.rel.IgniteConvention;
+import org.apache.ignite.internal.sql.engine.rel.IgniteProject;
import org.apache.ignite.internal.sql.engine.rel.IgniteRel;
import
org.apache.ignite.internal.sql.engine.rel.agg.IgniteColocatedHashAggregate;
import org.apache.ignite.internal.sql.engine.rel.agg.IgniteMapHashAggregate;
@@ -106,7 +108,10 @@ public class HashAggregateConverterRule {
RelTraitSet outTrait =
cluster.traitSetOf(IgniteConvention.INSTANCE);
Mapping fieldMappingOnReduce =
Commons.trimmingMapping(agg.getGroupSet().length(), agg.getGroupSet());
+ RelTraitSet reducePhaseTraits =
outTrait.replace(IgniteDistributions.single());
+
AggregateRelBuilder relBuilder = new AggregateRelBuilder() {
+
@Override
public IgniteRel makeMapAgg(RelOptCluster cluster, RelNode
input, ImmutableBitSet groupSet,
List<ImmutableBitSet> groupSets, List<AggregateCall>
aggregateCalls) {
@@ -121,13 +126,26 @@ public class HashAggregateConverterRule {
}
@Override
- public IgniteRel makeReduceAgg(RelOptCluster cluster, RelNode
map, ImmutableBitSet groupSet,
+ public IgniteRel makeProject(RelOptCluster cluster, RelNode
input, List<RexNode> reduceInputExprs,
+ RelDataType projectRowType) {
+
+ return new IgniteProject(
+ agg.getCluster(),
+ reducePhaseTraits,
+ convert(input,
inTrait.replace(IgniteDistributions.single())),
+ reduceInputExprs,
+ projectRowType
+ );
+ }
+
+ @Override
+ public IgniteRel makeReduceAgg(RelOptCluster cluster, RelNode
input, ImmutableBitSet groupSet,
List<ImmutableBitSet> groupSets, List<AggregateCall>
aggregateCalls, RelDataType outputType) {
return new IgniteReduceHashAggregate(
cluster,
- outTrait.replace(IgniteDistributions.single()),
- convert(map,
inTrait.replace(IgniteDistributions.single())),
+ reducePhaseTraits,
+ convert(input,
inTrait.replace(IgniteDistributions.single())),
groupSet,
groupSets,
aggregateCalls,
diff --git
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rule/SortAggregateConverterRule.java
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rule/SortAggregateConverterRule.java
index 6dd3c82009..5b90131995 100644
---
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rule/SortAggregateConverterRule.java
+++
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rule/SortAggregateConverterRule.java
@@ -34,9 +34,11 @@ import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexNode;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.ignite.internal.sql.engine.rel.IgniteConvention;
+import org.apache.ignite.internal.sql.engine.rel.IgniteProject;
import org.apache.ignite.internal.sql.engine.rel.IgniteRel;
import
org.apache.ignite.internal.sql.engine.rel.agg.IgniteColocatedSortAggregate;
import org.apache.ignite.internal.sql.engine.rel.agg.IgniteMapSortAggregate;
@@ -135,6 +137,7 @@ public class SortAggregateConverterRule {
RelTraitSet outTraits =
cluster.traitSetOf(IgniteConvention.INSTANCE).replace(outputCollation);
AggregateRelBuilder relBuilder = new AggregateRelBuilder() {
+
@Override
public IgniteRel makeMapAgg(RelOptCluster cluster, RelNode
input, ImmutableBitSet groupSet,
List<ImmutableBitSet> groupSets, List<AggregateCall>
aggregateCalls) {
@@ -151,13 +154,25 @@ public class SortAggregateConverterRule {
}
@Override
- public IgniteRel makeReduceAgg(RelOptCluster cluster, RelNode
map, ImmutableBitSet groupSet,
+ public IgniteRel makeProject(RelOptCluster cluster, RelNode
input, List<RexNode> reduceInputExprs,
+ RelDataType projectRowType) {
+
+ return new IgniteProject(agg.getCluster(),
+ outTraits.replace(IgniteDistributions.single()),
+ convert(input,
outTraits.replace(IgniteDistributions.single())),
+ reduceInputExprs,
+ projectRowType
+ );
+ }
+
+ @Override
+ public IgniteRel makeReduceAgg(RelOptCluster cluster, RelNode
input, ImmutableBitSet groupSet,
List<ImmutableBitSet> groupSets, List<AggregateCall>
aggregateCalls, RelDataType outputType) {
return new IgniteReduceSortAggregate(
cluster,
outTraits.replace(IgniteDistributions.single()),
- convert(map,
outTraits.replace(IgniteDistributions.single())),
+ convert(input,
outTraits.replace(IgniteDistributions.single())),
groupSet,
groupSets,
aggregateCalls,
diff --git
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/sql/fun/IgniteSqlOperatorTable.java
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/sql/fun/IgniteSqlOperatorTable.java
index 141bff461a..6d2e733f20 100644
---
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/sql/fun/IgniteSqlOperatorTable.java
+++
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/sql/fun/IgniteSqlOperatorTable.java
@@ -19,6 +19,7 @@ package org.apache.ignite.internal.sql.engine.sql.fun;
import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.SqlBasicFunction;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
@@ -190,6 +191,31 @@ public class IgniteSqlOperatorTable extends
ReflectiveSqlOperatorTable {
OperandTypes.CHARACTER.or(OperandTypes.BINARY),
SqlFunctionCategory.NUMERIC);
+ /**
+ * Division operator used by the REDUCE phase of the AVG aggregate.
+ * Uses the provided {@code precision} and {@code scale} literal operands to
infer the return type.
+ */
+ public static final SqlFunction DECIMAL_DIVIDE =
SqlBasicFunction.create("DECIMAL_DIVIDE",
+ new SqlReturnTypeInference() {
+ @Override
+ public @Nullable RelDataType
inferReturnType(SqlOperatorBinding opBinding) {
+ RelDataType arg0 = opBinding.getOperandType(0);
+ Integer precision = opBinding.getOperandLiteralValue(2,
Integer.class);
+ Integer scale = opBinding.getOperandLiteralValue(3,
Integer.class);
+
+ assert precision != null : "precision is not specified: "
+ opBinding.getOperator();
+ assert scale != null : "scale is not specified: " +
opBinding.getOperator();
+
+ boolean nullable = arg0.isNullable();
+ RelDataTypeFactory typeFactory =
opBinding.getTypeFactory();
+
+ RelDataType returnType =
typeFactory.createSqlType(SqlTypeName.DECIMAL, precision, scale);
+ return typeFactory.createTypeWithNullability(returnType,
nullable);
+ }
+ },
+ OperandTypes.DIVISION_OPERATOR,
+ SqlFunctionCategory.NUMERIC);
+
/** Singleton instance. */
public static final IgniteSqlOperatorTable INSTANCE = new
IgniteSqlOperatorTable();
@@ -235,6 +261,7 @@ public class IgniteSqlOperatorTable extends
ReflectiveSqlOperatorTable {
register(SqlStdOperatorTable.SUM);
register(SqlStdOperatorTable.SUM0);
register(SqlStdOperatorTable.AVG);
+ register(DECIMAL_DIVIDE);
register(SqlStdOperatorTable.MIN);
register(SqlStdOperatorTable.MAX);
register(SqlStdOperatorTable.ANY_VALUE);
diff --git
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/IgniteMethod.java
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/IgniteMethod.java
index 5d419abbf7..0ac553b9ec 100644
---
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/IgniteMethod.java
+++
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/IgniteMethod.java
@@ -21,6 +21,7 @@ import static
org.apache.ignite.internal.lang.IgniteStringFormatter.format;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
+import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Objects;
import java.util.TimeZone;
@@ -125,7 +126,14 @@ public enum IgniteMethod {
*/
TRUNCATE(IgniteSqlFunctions.class, "struncate", true),
- SUBSTRING(IgniteSqlFunctions.class, "substring", true);
+ SUBSTRING(IgniteSqlFunctions.class, "substring", true),
+
+ /**
+ * Division operator used by REDUCE phase of AVG aggregate.
+ * See {@link IgniteSqlFunctions#decimalDivide(BigDecimal, BigDecimal,
int, int)}.
+ */
+ DECIMAL_DIVIDE(IgniteSqlFunctions.class, "decimalDivide",
BigDecimal.class, BigDecimal.class, int.class, int.class),
+ ;
private final Method method;
diff --git
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/PlanUtils.java
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/PlanUtils.java
index 24e5e49602..90fe27e614 100644
---
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/PlanUtils.java
+++
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/PlanUtils.java
@@ -127,9 +127,17 @@ public class PlanUtils {
for (int i = 0; i < aggregateCalls.size(); i++) {
AggregateCall call = aggregateCalls.get(i);
-
Accumulator acc = accumulators.accumulatorFactory(call).get();
- RelDataType fieldType = acc.returnType(typeFactory);
+ RelDataType fieldType;
+ // For a decimal type, Accumulator::returnType returns a type with
default precision and scale,
+ // which can cause precision loss when a tuple is sent over the
wire by an exchanger/outbox.
+ // The Outbox uses its input type as the wire format, so if the scale is 0,
then the scale is lost
+ // (see Outbox::sendBatch -> RowHandler::toBinaryTuple ->
BinaryTupleBuilder::appendDecimalNotNull).
+ if (call.getType().getSqlTypeName().allowsScale()) {
+ fieldType = call.type;
+ } else {
+ fieldType = acc.returnType(typeFactory);
+ }
String fieldName = "_ACC" + i;
builder.add(fieldName, fieldType);
diff --git
a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctionsTest.java
b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctionsTest.java
index 5b2b77cb30..813ffb9532 100644
---
a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctionsTest.java
+++
b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctionsTest.java
@@ -34,6 +34,7 @@ import java.util.function.Supplier;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.ignite.internal.sql.engine.type.IgniteTypeSystem;
import org.apache.ignite.lang.ErrorGroups.Sql;
+import org.jetbrains.annotations.Nullable;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
@@ -516,4 +517,26 @@ public class IgniteSqlFunctionsTest {
assertEquals(Instant.ofEpochMilli(expMillis),
Instant.ofEpochMilli(actualTs));
}
+
+ @ParameterizedTest
+ @CsvSource(
+ value = {
+ "1; 2; 0.5",
+ "1; 3; 0.3333333333333333",
+ "6; 2; 3",
+ "1; 0;",
+ },
+ delimiterString = ";"
+ )
+ public void testAvgDivide(String a, String b, @Nullable String expected) {
+ BigDecimal num = new BigDecimal(a);
+ BigDecimal denum = new BigDecimal(b);
+
+ if (expected != null) {
+ BigDecimal actual = IgniteSqlFunctions.decimalDivide(num, denum,
4, 2);
+ assertEquals(new BigDecimal(expected), actual);
+ } else {
+ assertThrows(ArithmeticException.class, () ->
IgniteSqlFunctions.decimalDivide(num, denum, 4, 2));
+ }
+ }
}
diff --git
a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/HashAggregateExecutionTest.java
b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/HashAggregateExecutionTest.java
index daaa6e7f25..bb4355ca79 100644
---
a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/HashAggregateExecutionTest.java
+++
b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/HashAggregateExecutionTest.java
@@ -118,7 +118,13 @@ public class HashAggregateExecutionTest extends
BaseAggregateTest {
ImmutableBitSet grpSet = grpSets.get(0);
Mapping reduceMapping = Commons.trimmingMapping(grpSet.length(),
grpSet);
- MapReduceAgg mapReduceAgg =
MapReduceAggregates.createMapReduceAggCall(call,
reduceMapping.getTargetCount());
+ MapReduceAgg mapReduceAgg = MapReduceAggregates.createMapReduceAggCall(
+ Commons.cluster(),
+ call,
+ reduceMapping.getTargetCount(),
+ inRowType,
+ true
+ );
HashAggregateNode<Object[]> aggRdc = new HashAggregateNode<>(
ctx,
diff --git
a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/HashAggregateSingleGroupExecutionTest.java
b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/HashAggregateSingleGroupExecutionTest.java
index 249ca1b956..fdce4dbf2d 100644
---
a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/HashAggregateSingleGroupExecutionTest.java
+++
b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/HashAggregateSingleGroupExecutionTest.java
@@ -202,7 +202,7 @@ public class HashAggregateSingleGroupExecutionTest extends
AbstractExecutionTest
map.register(scan);
RelDataType hashRowType = PlanUtils.createHashAggRowType(grpSets, tf,
rowType, List.of(mapCall));
- MapReduceAgg reduceAggCall =
MapReduceAggregates.createMapReduceAggCall(mapCall, 0);
+ MapReduceAgg reduceAggCall =
MapReduceAggregates.createMapReduceAggCall(Commons.cluster(), mapCall, 0,
rowType, true);
HashAggregateNode<Object[]> reduce = new HashAggregateNode<>(ctx,
REDUCE, grpSets,
accFactory(ctx, reduceAggCall.getReduceCall(), REDUCE,
hashRowType), rowFactory());
@@ -431,7 +431,7 @@ public class HashAggregateSingleGroupExecutionTest extends
AbstractExecutionTest
IgniteTypeFactory tf = Commons.typeFactory();
RelDataType hashRowType = PlanUtils.createHashAggRowType(grpSets, tf,
rowType, List.of(call));
- MapReduceAgg reduceAggCall =
MapReduceAggregates.createMapReduceAggCall(call, 0);
+ MapReduceAgg reduceAggCall =
MapReduceAggregates.createMapReduceAggCall(Commons.cluster(), call, 0, rowType,
true);
return newHashAggNode(ctx, REDUCE, grpSets, hashRowType,
reduceAggCall.getReduceCall());
}
diff --git
a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/SortAggregateExecutionTest.java
b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/SortAggregateExecutionTest.java
index f40ccda15c..765ac75912 100644
---
a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/SortAggregateExecutionTest.java
+++
b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/rel/SortAggregateExecutionTest.java
@@ -137,7 +137,13 @@ public class SortAggregateExecutionTest extends
BaseAggregateTest {
}
Mapping mapping = Commons.trimmingMapping(grpSet.length(), grpSet);
- MapReduceAgg mapReduceAgg =
MapReduceAggregates.createMapReduceAggCall(call, mapping.getTargetCount());
+ MapReduceAgg mapReduceAgg = MapReduceAggregates.createMapReduceAggCall(
+ Commons.cluster(),
+ call,
+ mapping.getTargetCount(),
+ inRowType,
+ true
+ );
SortAggregateNode<Object[]> aggRdc = new SortAggregateNode<>(
ctx,
diff --git
a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/AbstractPlannerTest.java
b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/AbstractPlannerTest.java
index 95ac9c5fda..ec2456f5dd 100644
---
a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/AbstractPlannerTest.java
+++
b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/AbstractPlannerTest.java
@@ -65,6 +65,7 @@ import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.schema.ColumnStrategy;
import org.apache.calcite.schema.SchemaPlus;
+import org.apache.calcite.sql.SqlExplainFormat;
import org.apache.calcite.sql.SqlExplainLevel;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlNode;
@@ -483,6 +484,9 @@ public abstract class AbstractPlannerTest extends
IgniteAbstractTest {
) throws Exception {
IgniteRel plan = physicalPlan(sql, schemas, hintStrategies, params,
null, disabledRules);
+ String planString = RelOptUtil.dumpPlan("", plan,
SqlExplainFormat.TEXT, SqlExplainLevel.ALL_ATTRIBUTES);
+ log.info("statement: {}\n{}", sql, planString);
+
checkSplitAndSerialization(plan, schemas);
try {
diff --git
a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceAggregatesTest.java
b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceAggregatesTest.java
index 9a4ded58dc..bd33ecfd99 100644
---
a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceAggregatesTest.java
+++
b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceAggregatesTest.java
@@ -30,6 +30,7 @@ import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalValues;
import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.ImmutableBitSet;
@@ -37,6 +38,7 @@ import org.apache.calcite.util.Pair;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.MappingType;
import org.apache.calcite.util.mapping.Mappings;
+import org.apache.ignite.internal.sql.engine.rel.IgniteProject;
import org.apache.ignite.internal.sql.engine.rel.IgniteRel;
import org.apache.ignite.internal.sql.engine.rel.IgniteValues;
import org.apache.ignite.internal.sql.engine.rel.agg.MapReduceAggregates;
@@ -117,9 +119,14 @@ public class MapReduceAggregatesTest {
return createOutExpr(cluster, input);
}
+ @Override
+ public IgniteRel makeProject(RelOptCluster cluster, RelNode input,
List<RexNode> reduceInputExprs, RelDataType projectRowType) {
+ return new IgniteProject(cluster, input.getTraitSet(), input,
reduceInputExprs, projectRowType);
+ }
+
@Override
public IgniteRel makeReduceAgg(RelOptCluster cluster,
- RelNode map,
+ RelNode input,
ImmutableBitSet groupSet,
List<ImmutableBitSet> groupSets,
List<AggregateCall> aggregateCalls,
@@ -127,7 +134,7 @@ public class MapReduceAggregatesTest {
collectedGroupSets.add(Pair.of(groupSet, groupSets));
- return createOutExpr(cluster, map);
+ return createOutExpr(cluster, input);
}
private IgniteValues createOutExpr(RelOptCluster cluster, RelNode
input) {
diff --git
a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceHashAggregatePlannerTest.java
b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceHashAggregatePlannerTest.java
index 7d803c3277..5d1813b1b0 100644
---
a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceHashAggregatePlannerTest.java
+++
b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceHashAggregatePlannerTest.java
@@ -18,9 +18,6 @@
package org.apache.ignite.internal.sql.engine.planner;
import static java.util.function.Predicate.not;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.containsString;
-import static org.junit.jupiter.api.Assertions.assertThrows;
import java.util.List;
import java.util.Objects;
@@ -34,7 +31,6 @@ import
org.apache.ignite.internal.sql.engine.rel.IgniteExchange;
import org.apache.ignite.internal.sql.engine.rel.IgniteLimit;
import org.apache.ignite.internal.sql.engine.rel.IgniteMergeJoin;
import org.apache.ignite.internal.sql.engine.rel.IgniteProject;
-import org.apache.ignite.internal.sql.engine.rel.IgniteRel;
import org.apache.ignite.internal.sql.engine.rel.IgniteSort;
import org.apache.ignite.internal.sql.engine.rel.agg.IgniteMapHashAggregate;
import org.apache.ignite.internal.sql.engine.rel.agg.IgniteReduceHashAggregate;
@@ -476,42 +472,42 @@ public class MapReduceHashAggregatePlannerTest extends
AbstractAggregatePlannerT
assertPlan(TestCase.CASE_22, nonColocated, disableRules);
assertPlan(TestCase.CASE_22A, nonColocated, disableRules);
-
- Predicate<RelNode> colocated =
hasChildThat(isInstanceOf(IgniteReduceHashAggregate.class)
- .and(in ->
hasAggregates(countReduce).test(in.getAggregateCalls()))
- .and(input(isInstanceOf(IgniteExchange.class)
- .and(input(isInstanceOf(IgniteMapHashAggregate.class)
- .and(in ->
hasAggregates(countMap).test(in.getAggCallList()))
- .and(input(isTableScan("TEST")))
- )
- ))
- ));
-
- assertPlan(TestCase.CASE_22B, colocated, disableRules);
- assertPlan(TestCase.CASE_22C, colocated, disableRules);
+ assertPlan(TestCase.CASE_22B, nonColocated, disableRules);
+ assertPlan(TestCase.CASE_22C, nonColocated, disableRules);
}
/**
- * Validates that AVG can not be used as two phase mode.
- * Should be fixed with TODO
https://issues.apache.org/jira/browse/IGNITE-20009
+ * Validates that the AVG aggregate is split into multiple expressions.
*/
@Test
- public void testAvgAgg() {
- RuntimeException e = assertThrows(RuntimeException.class,
- () -> assertPlan(TestCase.CASE_23,
isInstanceOf(IgniteRel.class), disableRules));
- assertThat(e.getMessage(), containsString("There are not enough rules
to produce a node with desired properties"));
-
- e = assertThrows(RuntimeException.class,
- () -> assertPlan(TestCase.CASE_23A,
isInstanceOf(IgniteRel.class), disableRules));
- assertThat(e.getMessage(), containsString("There are not enough rules
to produce a node with desired properties"));
-
- e = assertThrows(RuntimeException.class,
- () -> assertPlan(TestCase.CASE_23B,
isInstanceOf(IgniteRel.class), disableRules));
- assertThat(e.getMessage(), containsString("There are not enough rules
to produce a node with desired properties"));
-
- e = assertThrows(RuntimeException.class,
- () -> assertPlan(TestCase.CASE_23C,
isInstanceOf(IgniteRel.class), disableRules));
- assertThat(e.getMessage(), containsString("There are not enough rules
to produce a node with desired properties"));
+ public void twoPhaseAvgAgg() throws Exception {
+ Predicate<AggregateCall> sumMap = (a) ->
+ Objects.equals(a.getAggregation().getName(), "SUM") &&
a.getArgList().equals(List.of(1));
+
+ Predicate<AggregateCall> countMap = (a) ->
+ Objects.equals(a.getAggregation().getName(), "COUNT") &&
a.getArgList().equals(List.of(1));
+
+ Predicate<AggregateCall> sumReduce = (a) ->
+ Objects.equals(a.getAggregation().getName(), "SUM") &&
a.getArgList().equals(List.of(1));
+
+ Predicate<AggregateCall> sum0Reduce = (a) ->
+ Objects.equals(a.getAggregation().getName(), "$SUM0") &&
a.getArgList().equals(List.of(2));
+
+ Predicate<RelNode> nonColocated =
hasChildThat(isInstanceOf(IgniteReduceHashAggregate.class)
+ .and(in -> hasAggregates(sumReduce,
sum0Reduce).test(in.getAggregateCalls()))
+ .and(input(isInstanceOf(IgniteProject.class)
+ .and(input(isInstanceOf(IgniteExchange.class)
+
.and(hasDistribution(IgniteDistributions.single()))
+
.and(input(isInstanceOf(IgniteMapHashAggregate.class)
+ .and(in ->
hasAggregates(sumMap, countMap).test(in.getAggCallList()))
+ )
+ ))
+ ))));
+
+ assertPlan(TestCase.CASE_23, nonColocated, disableRules);
+ assertPlan(TestCase.CASE_23A, nonColocated, disableRules);
+ assertPlan(TestCase.CASE_23B, nonColocated, disableRules);
+ assertPlan(TestCase.CASE_23C, nonColocated, disableRules);
}
/**
diff --git
a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceSortAggregatePlannerTest.java
b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceSortAggregatePlannerTest.java
index 88543a92ee..5329cccbc0 100644
---
a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceSortAggregatePlannerTest.java
+++
b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/MapReduceSortAggregatePlannerTest.java
@@ -19,9 +19,6 @@ package org.apache.ignite.internal.sql.engine.planner;
import static java.util.function.Predicate.not;
import static
org.apache.ignite.internal.sql.engine.trait.IgniteDistributions.single;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.containsString;
-import static org.junit.jupiter.api.Assertions.assertThrows;
import java.util.List;
import java.util.Objects;
@@ -36,7 +33,6 @@ import
org.apache.ignite.internal.sql.engine.rel.IgniteExchange;
import org.apache.ignite.internal.sql.engine.rel.IgniteLimit;
import org.apache.ignite.internal.sql.engine.rel.IgniteMergeJoin;
import org.apache.ignite.internal.sql.engine.rel.IgniteProject;
-import org.apache.ignite.internal.sql.engine.rel.IgniteRel;
import org.apache.ignite.internal.sql.engine.rel.IgniteSort;
import org.apache.ignite.internal.sql.engine.rel.agg.IgniteMapSortAggregate;
import org.apache.ignite.internal.sql.engine.rel.agg.IgniteReduceSortAggregate;
@@ -482,25 +478,38 @@ public class MapReduceSortAggregatePlannerTest extends
AbstractAggregatePlannerT
}
/**
- * Validates that AVG can not be used as two phase mode. Should be fixed
with TODO https://issues.apache.org/jira/browse/IGNITE-20009
- * Validates that AVG can not be used as two phase mode. Should be fixed
with TODO https://issues.apache.org/jira/browse/IGNITE-20009
+ * Validates that the AVG aggregate is split into multiple expressions:
+ * SUM(col) AS s, COUNT(col) AS c on the map phase, followed by
SUM(s)/SUM0(c) on the reduce phase.
*/
@Test
- public void testAvgAgg() {
- RuntimeException e = assertThrows(RuntimeException.class,
- () -> assertPlan(TestCase.CASE_23,
isInstanceOf(IgniteRel.class), disableRules));
- assertThat(e.getMessage(), containsString("There are not enough rules
to produce a node with desired properties"));
-
- e = assertThrows(RuntimeException.class,
- () -> assertPlan(TestCase.CASE_23A,
isInstanceOf(IgniteRel.class), disableRules));
- assertThat(e.getMessage(), containsString("There are not enough rules
to produce a node with desired properties"));
-
- e = assertThrows(RuntimeException.class,
- () -> assertPlan(TestCase.CASE_23B,
isInstanceOf(IgniteRel.class), disableRules));
- assertThat(e.getMessage(), containsString("There are not enough rules
to produce a node with desired properties"));
-
- e = assertThrows(RuntimeException.class,
- () -> assertPlan(TestCase.CASE_23C,
isInstanceOf(IgniteRel.class), disableRules));
- assertThat(e.getMessage(), containsString("There are not enough rules
to produce a node with desired properties"));
+ public void twoPhaseAvgAgg() throws Exception {
+ Predicate<AggregateCall> sumMap = (a) ->
+ Objects.equals(a.getAggregation().getName(), "SUM") &&
a.getArgList().equals(List.of(1));
+
+ Predicate<AggregateCall> countMap = (a) ->
+ Objects.equals(a.getAggregation().getName(), "COUNT") &&
a.getArgList().equals(List.of(1));
+
+ Predicate<AggregateCall> sumReduce = (a) ->
+ Objects.equals(a.getAggregation().getName(), "SUM") &&
a.getArgList().equals(List.of(1));
+
+ Predicate<AggregateCall> sum0Reduce = (a) ->
+ Objects.equals(a.getAggregation().getName(), "$SUM0") &&
a.getArgList().equals(List.of(2));
+
+ Predicate<RelNode> nonColocated =
hasChildThat(isInstanceOf(IgniteReduceSortAggregate.class)
+ .and(in -> hasAggregates(sumReduce,
sum0Reduce).test(in.getAggregateCalls()))
+ .and(input(isInstanceOf(IgniteProject.class)
+ .and(input(isInstanceOf(IgniteExchange.class)
+ .and(hasDistribution(single()))
+ .and(input(isInstanceOf(IgniteMapSortAggregate.class)
+ .and(in -> hasAggregates(sumMap,
countMap).test(in.getAggCallList()))
+ )
+ ))
+ ))));
+
+ assertPlan(TestCase.CASE_23, nonColocated, disableRules);
+ assertPlan(TestCase.CASE_23A, nonColocated, disableRules);
+ assertPlan(TestCase.CASE_23B, nonColocated, disableRules);
+ assertPlan(TestCase.CASE_23C, nonColocated, disableRules);
}
/**
diff --git
a/modules/sql-engine/src/testFixtures/java/org/apache/ignite/internal/sql/engine/util/QueryChecker.java
b/modules/sql-engine/src/testFixtures/java/org/apache/ignite/internal/sql/engine/util/QueryChecker.java
index 36a46d2009..96a1f6e1ee 100644
---
a/modules/sql-engine/src/testFixtures/java/org/apache/ignite/internal/sql/engine/util/QueryChecker.java
+++
b/modules/sql-engine/src/testFixtures/java/org/apache/ignite/internal/sql/engine/util/QueryChecker.java
@@ -187,8 +187,8 @@ public interface QueryChecker {
static void assertEqualsCollections(Collection<?> exp, Collection<?> act) {
if (exp.size() != act.size()) {
String errorMsg = new IgniteStringBuilder("Collections sizes are
not equal:").nl()
- .app("\tExpected: ").app(exp.size()).nl()
- .app("\t Actual: ").app(act.size()).toString();
+ .app("\tExpected: ").app(exp.size()).app(exp).nl()
+ .app("\t Actual: ").app(act.size()).app(act).toString();
throw new AssertionError(errorMsg);
}