This is an automated email from the ASF dual-hosted git repository. dzamo pushed a commit to branch 1.21 in repository https://gitbox.apache.org/repos/asf/drill.git
commit 866f420573e18ce847f2796348d138afe3fe053a Author: Volodymyr Vysotskyi <[email protected]> AuthorDate: Thu Feb 23 10:54:38 2023 +0200 DRILL-8403: Generated aggregate function calls are missing required filters when used with PIVOT (#2765) --- .../planner/logical/DrillReduceAggregatesRule.java | 47 ++++++++++++++------ .../drill/exec/planner/physical/AggPrelBase.java | 51 ++++++++++++++++------ .../drill/exec/fn/impl/TestAggregateFunctions.java | 18 ++++++++ 3 files changed, 89 insertions(+), 27 deletions(-) diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java index 062fda0c34..b386361adf 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java @@ -343,12 +343,10 @@ public class DrillReduceAggregatesRule extends RelOptRule { SqlAggFunction sumAgg = new DrillCalciteSqlSumEmptyIsZeroAggFunctionWrapper( new SqlSumEmptyIsZeroAggFunction(), sumType); - AggregateCall sumCall = AggregateCall.create(sumAgg, oldCall.isDistinct(), - oldCall.isApproximate(), oldCall.getArgList(), -1, sumType, null); + AggregateCall sumCall = getAggCall(oldCall, sumAgg, sumType); final SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT; final RelDataType countType = countAgg.getReturnType(typeFactory); - AggregateCall countCall = AggregateCall.create(countAgg, oldCall.isDistinct(), - oldCall.isApproximate(), oldCall.getArgList(), -1, countType, null); + AggregateCall countCall = getAggCall(oldCall, countAgg, countType); RexNode tmpsumRef = rexBuilder.addAggCall( @@ -414,6 +412,21 @@ public class DrillReduceAggregatesRule extends RelOptRule { } } + private static AggregateCall getAggCall(AggregateCall oldCall, + SqlAggFunction aggFunction, + RelDataType sumType) { + return AggregateCall.create(aggFunction, + oldCall.isDistinct(), + oldCall.isApproximate(), + oldCall.ignoreNulls(), + oldCall.getArgList(), + oldCall.filterArg, + oldCall.distinctKeys, + oldCall.getCollation(), + sumType, + null); + } + private RexNode reduceSum( Aggregate oldAggRel, AggregateCall oldCall, @@ -441,12 +454,10 @@ public class DrillReduceAggregatesRule extends RelOptRule { } sumZeroAgg = new DrillCalciteSqlSumEmptyIsZeroAggFunctionWrapper( new SqlSumEmptyIsZeroAggFunction(), sumType); - AggregateCall sumZeroCall = AggregateCall.create(sumZeroAgg, oldCall.isDistinct(), - oldCall.isApproximate(), oldCall.getArgList(), -1, sumType, null); + AggregateCall sumZeroCall = getAggCall(oldCall, sumZeroAgg, sumType); final SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT; final RelDataType countType = countAgg.getReturnType(typeFactory); - AggregateCall countCall = AggregateCall.create(countAgg, oldCall.isDistinct(), - oldCall.isApproximate(), oldCall.getArgList(), -1, countType, null); + AggregateCall countCall = getAggCall(oldCall, countAgg, countType); // NOTE: these references are with respect to the output // of newAggRel RexNode sumZeroRef = @@ -529,8 +540,11 @@ public class DrillReduceAggregatesRule extends RelOptRule { new SqlSumAggFunction(sumType), sumType), oldCall.isDistinct(), oldCall.isApproximate(), + oldCall.ignoreNulls(), ImmutableIntList.of(argSquaredOrdinal), - -1, + oldCall.filterArg, + oldCall.distinctKeys, + oldCall.getCollation(), sumType, null); final RexNode sumArgSquared = @@ -547,8 +561,11 @@ public class DrillReduceAggregatesRule extends RelOptRule { new SqlSumAggFunction(sumType), sumType), oldCall.isDistinct(), oldCall.isApproximate(), + oldCall.ignoreNulls(), ImmutableIntList.of(argOrdinal), - -1, + oldCall.filterArg, + oldCall.distinctKeys, + oldCall.getCollation(), sumType, null); final RexNode sumArg = @@ -565,8 +582,7 @@ public class DrillReduceAggregatesRule extends RelOptRule { final SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT; final RelDataType countType = countAgg.getReturnType(typeFactory); - final AggregateCall countArgAggCall = AggregateCall.create(countAgg, oldCall.isDistinct(), - oldCall.isApproximate(), oldCall.getArgList(), -1, countType, null); + final AggregateCall countArgAggCall = getAggCall(oldCall, countAgg, countType); final RexNode countArg = rexBuilder.addAggCall( countArgAggCall, @@ -677,7 +693,7 @@ public class DrillReduceAggregatesRule extends RelOptRule { RelNode inputRel, List<AggregateCall> newCalls) { RelOptCluster cluster = inputRel.getCluster(); - return new LogicalAggregate(cluster, cluster.traitSetOf(Convention.NONE), + return new LogicalAggregate(cluster, cluster.traitSetOf(Convention.NONE), Collections.emptyList(), inputRel, oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), newCalls); } @@ -722,8 +738,11 @@ public class DrillReduceAggregatesRule extends RelOptRule { sumZeroAgg, oldAggregateCall.isDistinct(), oldAggregateCall.isApproximate(), + oldAggregateCall.ignoreNulls(), oldAggregateCall.getArgList(), - -1, + oldAggregateCall.filterArg, + oldAggregateCall.distinctKeys, + oldAggregateCall.getCollation(), sumType, oldAggregateCall.getName()); oldAggRel.getCluster().getRexBuilder() diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java index f9a7d0e099..a8619aa8d3 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java @@ -33,6 +33,7 @@ import org.apache.drill.exec.planner.common.DrillAggregateRelBase; import org.apache.drill.exec.planner.physical.visitor.PrelVisitor; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.InvalidRelException; +import org.apache.calcite.rel.RelCollations; import org.apache.calcite.rel.RelNode; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitSet; @@ -42,6 +43,7 @@ import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.util.Optionality; import java.util.Collections; import java.util.Iterator; @@ -61,7 +63,7 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel // phase PHASE_2of2("2nd"); - private String name; + private final String name; OperatorPhase(String name) { this.name = name; @@ -99,7 +101,7 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel * creating a SUM whose return type is non-nullable. * */ - public class SqlSumCountAggFunction extends SqlAggFunction { + public static class SqlSumCountAggFunction extends SqlAggFunction { private final RelDataType type; @@ -112,7 +114,8 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel OperandTypes.NUMERIC, SqlFunctionCategory.NUMERIC, false, - false); + false, + Optionality.FORBIDDEN); this.type = type; } @@ -175,8 +178,11 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel sumAggFun, aggCall.e.isDistinct(), aggCall.e.isApproximate(), + false, Collections.singletonList(aggExprOrdinal), aggCall.e.filterArg, + null, + RelCollations.EMPTY, aggCall.e.getType(), aggCall.e.getName()); @@ -187,8 +193,11 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel aggCall.e.getAggregation(), aggCall.e.isDistinct(), aggCall.e.isApproximate(), + false, Collections.singletonList(aggExprOrdinal), aggCall.e.filterArg, + null, + RelCollations.EMPTY, aggCall.e.getType(), aggCall.e.getName()); @@ -202,21 +211,29 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel List<LogicalExpression> args = Lists.newArrayList(); for (Integer i : call.getArgList()) { LogicalExpression expr = FieldReference.getWithQuotedRef(fn.get(i)); - if (call.hasFilter()) { - expr = IfExpression.newBuilder() - .setIfCondition(new IfExpression.IfCondition(FieldReference.getWithQuotedRef(fn.get(call.filterArg)), expr)) - .setElse(NullExpression.INSTANCE) - .build(); - } + expr = getArgumentExpression(call, fn, expr); args.add(expr); } if (SqlKind.COUNT.name().equals(call.getAggregation().getName()) && args.isEmpty()) { - args.add(new ValueExpressions.LongExpression(1L)); + LogicalExpression expr = new ValueExpressions.LongExpression(1L); + expr = getArgumentExpression(call, fn, expr); + args.add(expr); } return new FunctionCall(call.getAggregation().getName().toLowerCase(), args, ExpressionPosition.UNKNOWN); } + private static LogicalExpression getArgumentExpression(AggregateCall call, List<String> fn, + LogicalExpression expr) { + if (call.hasFilter()) { + return IfExpression.newBuilder() + .setIfCondition(new IfExpression.IfCondition(FieldReference.getWithQuotedRef(fn.get(call.filterArg)), expr)) + .setElse(NullExpression.INSTANCE) + .build(); + } + return expr; + } + @Override public Iterator<Prel> iterator() { return PrelUtil.iter(getInput()); @@ -249,9 +266,17 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel for (int arg : aggCall.getArgList()) { arglist.add(arg + 1); } - aggregateCalls.add(AggregateCall.create(aggCall.getAggregation(), aggCall.isDistinct(), - aggCall.isApproximate(), arglist, aggCall.filterArg, aggCall.type, aggCall.name)); + aggregateCalls.add(AggregateCall.create(aggCall.getAggregation(), + aggCall.isDistinct(), + aggCall.isApproximate(), + false, + arglist, + aggCall.filterArg, + null, + RelCollations.EMPTY, + aggCall.type, + aggCall.name)); } - return (Prel) copy(traitSet, children.get(0),indicator,groupingSet,groupingSets, aggregateCalls); + return (Prel) copy(traitSet, children.get(0), groupingSet, groupingSets, aggregateCalls); } } diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestAggregateFunctions.java b/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestAggregateFunctions.java index 97f3c254b2..edf619074b 100644 --- a/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestAggregateFunctions.java +++ b/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestAggregateFunctions.java @@ -1271,4 +1271,22 @@ public class TestAggregateFunctions extends ClusterTest { .baselineValues(5L, 5L, 5L, 5L, 5L) .go(); } + + @Test + public void testAggregateWithPivot() throws Exception { + String query = "SELECT * FROM (\n" + + "SELECT education_level, salary, marital_status, extract(year from age(birth_date)) age\n" + + "FROM cp.`employee.json`)\n" + + "PIVOT (avg(salary) avg_salary, avg(age) avg_age FOR marital_status IN ('M' married, 'S' single))"; + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("education_level", "married_avg_salary", "married_avg_age", "single_avg_salary", "single_avg_age") + .baselineValues("Graduate Degree", 4038.470588235294, 101.98823529411764, 4747.176470588235, 98.65882352941176) + .baselineValues("Bachelors Degree", 4789.166666666667, 102.43055555555556, 4193.566433566433, 102.02797202797203) + .baselineValues("Partial College", 4281.381578947368, 99.25657894736842, 3785.294117647059, 101.04411764705883) + .baselineValues("High School Degree", 3459.2805755395684, 103.57553956834532, 3571.830985915493, 102.69014084507042) + .baselineValues("Partial High School", 3555.8064516129034, 101.14516129032258, 3469.7014925373132, 103.3731343283582) + .go(); + } }
