DRILL-3993: Fix aggregate exchange rules for cases when the aggregate rel node contains several calls
Project: http://git-wip-us.apache.org/repos/asf/drill/repo Commit: http://git-wip-us.apache.org/repos/asf/drill/commit/3c9093e3 Tree: http://git-wip-us.apache.org/repos/asf/drill/tree/3c9093e3 Diff: http://git-wip-us.apache.org/repos/asf/drill/diff/3c9093e3 Branch: refs/heads/master Commit: 3c9093e32a095bd40832bcd8fe67ab20898537c4 Parents: 22d0f7e Author: Volodymyr Vysotskyi <vvo...@gmail.com> Authored: Thu Jan 4 16:05:53 2018 +0200 Committer: Volodymyr Vysotskyi <vvo...@gmail.com> Committed: Tue Jan 16 12:10:13 2018 +0200 ---------------------------------------------------------------------- .../exec/planner/physical/AggPrelBase.java | 24 ++++++++++-------- .../exec/planner/physical/AggPruleBase.java | 26 ++++++++++++++++---- .../exec/planner/physical/HashAggPrule.java | 22 +++++++++++------ .../exec/planner/physical/StreamAggPrule.java | 25 ++++++++++++++----- 4 files changed, 69 insertions(+), 28 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/drill/blob/3c9093e3/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java ---------------------------------------------------------------------- diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java index 8c69930..232473b 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPrelBase.java @@ -1,4 +1,4 @@ -/** +/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. 
See the NOTICE file * distributed with this work for additional information @@ -31,7 +31,6 @@ import org.apache.drill.common.logical.data.NamedExpression; import org.apache.drill.exec.planner.common.DrillAggregateRelBase; import org.apache.drill.exec.planner.physical.visitor.PrelVisitor; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.InvalidRelException; import org.apache.calcite.rel.RelNode; import org.apache.calcite.plan.RelOptCluster; @@ -44,14 +43,13 @@ import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; -import java.util.BitSet; import java.util.Collections; import java.util.Iterator; import java.util.List; public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel { - public static enum OperatorPhase {PHASE_1of1, PHASE_1of2, PHASE_2of2}; + public enum OperatorPhase {PHASE_1of1, PHASE_1of2, PHASE_2of2} protected OperatorPhase operPhase = OperatorPhase.PHASE_1of1 ; // default phase protected List<NamedExpression> keys = Lists.newArrayList(); @@ -70,11 +68,14 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel public SqlSumCountAggFunction(RelDataType type) { super("$SUM0", + null, SqlKind.OTHER_FUNCTION, ReturnTypes.BIGINT, // use the inferred return type of SqlCountAggFunction null, OperandTypes.NUMERIC, - SqlFunctionCategory.NUMERIC); + SqlFunctionCategory.NUMERIC, + false, + false); this.type = type; } @@ -143,20 +144,24 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel // If we are doing a COUNT aggregate in Phase1of2, then in Phase2of2 we should SUM the COUNTs, SqlAggFunction sumAggFun = new SqlSumCountAggFunction(aggCall.e.getType()); AggregateCall newAggCall = - new AggregateCall( + AggregateCall.create( sumAggFun, aggCall.e.isDistinct(), + aggCall.e.isApproximate(), Collections.singletonList(aggExprOrdinal), 
+ aggCall.e.filterArg, aggCall.e.getType(), aggCall.e.getName()); phase2AggCallList.add(newAggCall); } else { AggregateCall newAggCall = - new AggregateCall( + AggregateCall.create( aggCall.e.getAggregation(), aggCall.e.isDistinct(), + aggCall.e.isApproximate(), Collections.singletonList(aggExprOrdinal), + aggCall.e.filterArg, aggCall.e.getType(), aggCall.e.getName()); @@ -174,10 +179,9 @@ public abstract class AggPrelBase extends DrillAggregateRelBase implements Prel // for count(1). if (args.isEmpty()) { - args.add(new ValueExpressions.LongExpression(1l)); + args.add(new ValueExpressions.LongExpression(1L)); } - LogicalExpression expr = new FunctionCall(call.getAggregation().getName().toLowerCase(), args, ExpressionPosition.UNKNOWN ); - return expr; + return new FunctionCall(call.getAggregation().getName().toLowerCase(), args, ExpressionPosition.UNKNOWN); } @Override http://git-wip-us.apache.org/repos/asf/drill/blob/3c9093e3/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPruleBase.java ---------------------------------------------------------------------- diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPruleBase.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPruleBase.java index 84e37fc..6863967 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPruleBase.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/AggPruleBase.java @@ -21,8 +21,8 @@ package org.apache.drill.exec.planner.physical; import java.util.List; import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.util.BitSets; +import org.apache.calcite.util.ImmutableBitSet; import org.apache.drill.exec.planner.logical.DrillAggregateRel; import org.apache.drill.exec.planner.physical.DrillDistributionTrait.DistributionField; import org.apache.calcite.rel.core.AggregateCall; @@ -42,7 +42,7 @@ public abstract class AggPruleBase extends Prule { protected 
List<DistributionField> getDistributionField(DrillAggregateRel rel, boolean allFields) { List<DistributionField> groupByFields = Lists.newArrayList(); - for (int group : BitSets.toIter(rel.getGroupSet())) { + for (int group : remapGroupSet(rel.getGroupSet())) { DistributionField field = new DistributionField(group); groupByFields.add(field); @@ -63,10 +63,11 @@ public abstract class AggPruleBase extends Prule { protected boolean create2PhasePlan(RelOptRuleCall call, DrillAggregateRel aggregate) { PlannerSettings settings = PrelUtil.getPlannerSettings(call.getPlanner()); RelNode child = call.rel(0).getInputs().get(0); - boolean smallInput = child.getRows() < settings.getSliceTarget(); - if (! settings.isMultiPhaseAggEnabled() || settings.isSingleMode() || + boolean smallInput = + child.estimateRowCount(child.getCluster().getMetadataQuery()) < settings.getSliceTarget(); + if (!settings.isMultiPhaseAggEnabled() || settings.isSingleMode() // Can override a small child - e.g., for testing with a small table - ( smallInput && ! settings.isForce2phaseAggr() ) ) { + || (smallInput && !settings.isForce2phaseAggr())) { return false; } @@ -82,4 +83,19 @@ public abstract class AggPruleBase extends Prule { } return true; } + + /** + * Returns group-by keys with the remapped arguments for specified aggregate. + * + * @param groupSet ImmutableBitSet of aggregate rel node, whose group-by keys should be remapped. + * @return {@link ImmutableBitSet} instance with remapped keys. 
+ */ + public static ImmutableBitSet remapGroupSet(ImmutableBitSet groupSet) { + List<Integer> newGroupSet = Lists.newArrayList(); + int groupSetToAdd = 0; + for (int ignored : groupSet) { + newGroupSet.add(groupSetToAdd++); + } + return ImmutableBitSet.of(newGroupSet); + } } http://git-wip-us.apache.org/repos/asf/drill/blob/3c9093e3/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/HashAggPrule.java ---------------------------------------------------------------------- diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/HashAggPrule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/HashAggPrule.java index f4cdf62..02dd4de 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/HashAggPrule.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/HashAggPrule.java @@ -17,6 +17,8 @@ */ package org.apache.drill.exec.planner.physical; +import com.google.common.collect.Lists; +import org.apache.calcite.util.ImmutableBitSet; import org.apache.drill.exec.planner.logical.DrillAggregateRel; import org.apache.drill.exec.planner.logical.RelOptHelper; import org.apache.drill.exec.planner.physical.AggPrelBase.OperatorPhase; @@ -31,6 +33,8 @@ import org.apache.calcite.util.trace.CalciteTrace; import com.google.common.collect.ImmutableList; import org.slf4j.Logger; +import java.util.List; + public class HashAggPrule extends AggPruleBase { public static final RelOptRule INSTANCE = new HashAggPrule(); protected static final Logger tracer = CalciteTrace.getPlannerTracer(); @@ -51,7 +55,7 @@ public class HashAggPrule extends AggPruleBase { return; } - final DrillAggregateRel aggregate = (DrillAggregateRel) call.rel(0); + final DrillAggregateRel aggregate = call.rel(0); final RelNode input = call.rel(1); if (aggregate.containsDistinctCall() || aggregate.getGroupCount() == 0) { @@ -60,7 +64,7 @@ public class HashAggPrule extends AggPruleBase { return; } - RelTraitSet 
traits = null; + RelTraitSet traits; try { if (aggregate.getGroupSet().isEmpty()) { @@ -125,18 +129,22 @@ public class HashAggPrule extends AggPruleBase { new HashToRandomExchangePrel(phase1Agg.getCluster(), phase1Agg.getTraitSet().plus(Prel.DRILL_PHYSICAL).plus(distOnAllKeys), phase1Agg, ImmutableList.copyOf(getDistributionField(aggregate, true))); - HashAggPrel phase2Agg = new HashAggPrel( + ImmutableBitSet newGroupSet = remapGroupSet(aggregate.getGroupSet()); + List<ImmutableBitSet> newGroupSets = Lists.newArrayList(); + for (ImmutableBitSet groupSet : aggregate.getGroupSets()) { + newGroupSets.add(remapGroupSet(groupSet)); + } + + return new HashAggPrel( aggregate.getCluster(), exch.getTraitSet(), exch, aggregate.indicator, - aggregate.getGroupSet(), - aggregate.getGroupSets(), + newGroupSet, + newGroupSets, phase1Agg.getPhase2AggCalls(), OperatorPhase.PHASE_2of2); - return phase2Agg; } - } private void createTransformRequest(RelOptRuleCall call, DrillAggregateRel aggregate, http://git-wip-us.apache.org/repos/asf/drill/blob/3c9093e3/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java ---------------------------------------------------------------------- diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java index a6a8f28..29fa750 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java @@ -21,6 +21,7 @@ import java.util.List; import org.apache.calcite.util.BitSets; +import org.apache.calcite.util.ImmutableBitSet; import org.apache.drill.exec.planner.logical.DrillAggregateRel; import org.apache.drill.exec.planner.logical.RelOptHelper; import org.apache.drill.exec.planner.physical.AggPrelBase.OperatorPhase; @@ -53,10 +54,10 @@ public class StreamAggPrule 
extends AggPruleBase { @Override public void onMatch(RelOptRuleCall call) { - final DrillAggregateRel aggregate = (DrillAggregateRel) call.rel(0); + final DrillAggregateRel aggregate = call.rel(0); RelNode input = aggregate.getInput(); final RelCollation collation = getCollation(aggregate); - RelTraitSet traits = null; + RelTraitSet traits; if (aggregate.containsDistinctCall()) { // currently, don't use StreamingAggregate if any of the logical aggrs contains DISTINCT @@ -93,13 +94,19 @@ public class StreamAggPrule extends AggPruleBase { UnionExchangePrel exch = new UnionExchangePrel(phase1Agg.getCluster(), singleDistTrait, phase1Agg); + ImmutableBitSet newGroupSet = remapGroupSet(aggregate.getGroupSet()); + List<ImmutableBitSet> newGroupSets = Lists.newArrayList(); + for (ImmutableBitSet groupSet : aggregate.getGroupSets()) { + newGroupSets.add(remapGroupSet(groupSet)); + } + return new StreamAggPrel( aggregate.getCluster(), singleDistTrait, exch, aggregate.indicator, - aggregate.getGroupSet(), - aggregate.getGroupSets(), + newGroupSet, + newGroupSets, phase1Agg.getPhase2AggCalls(), OperatorPhase.PHASE_2of2); } @@ -160,13 +167,19 @@ public class StreamAggPrule extends AggPruleBase { collation, numEndPoints); + ImmutableBitSet newGroupSet = remapGroupSet(aggregate.getGroupSet()); + List<ImmutableBitSet> newGroupSets = Lists.newArrayList(); + for (ImmutableBitSet groupSet : aggregate.getGroupSets()) { + newGroupSets.add(remapGroupSet(groupSet)); + } + return new StreamAggPrel( aggregate.getCluster(), exch.getTraitSet(), exch, aggregate.indicator, - aggregate.getGroupSet(), - aggregate.getGroupSets(), + newGroupSet, + newGroupSets, phase1Agg.getPhase2AggCalls(), OperatorPhase.PHASE_2of2); }