[FLINK-5144] [table] Fix error while applying rule AggregateJoinTransposeRule

This closes #3062.
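
As context, here is a minimal sketch (not part of this commit) of how a rule such as
FlinkAggregateJoinTransposeRule is typically applied with Calcite's heuristic planner.
The helper class, its name, and the input `root` (a LogicalAggregate on top of an inner
LogicalJoin) are assumptions for illustration; in the commit itself the rule is registered
through FlinkRuleSets.scala, as the diff stat below shows.

    import org.apache.calcite.plan.hep.HepPlanner;
    import org.apache.calcite.plan.hep.HepProgram;
    import org.apache.calcite.plan.hep.HepProgramBuilder;
    import org.apache.calcite.rel.RelNode;
    import org.apache.flink.table.calcite.rules.FlinkAggregateJoinTransposeRule;

    /** Illustrative helper only; not part of the commit. */
    class AggregatePushDownExample {
        /** Runs one heuristic pass that tries to push the aggregate below the join. */
        static RelNode pushAggregateBelowJoin(RelNode root) {
            HepProgram program = new HepProgramBuilder()
                .addRuleInstance(FlinkAggregateJoinTransposeRule.EXTENDED)
                .build();
            HepPlanner planner = new HepPlanner(program);
            planner.setRoot(root);
            return planner.findBestExp();
        }
    }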


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/e187b5ee
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/e187b5ee
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/e187b5ee

Branch: refs/heads/master
Commit: e187b5ee9aa6d1bf9feec151ff460d1a28c4e5f0
Parents: dba7d7d
Author: Kurt Young <ykt...@gmail.com>
Authored: Thu Jan 5 11:32:04 2017 +0800
Committer: twalthr <twal...@apache.org>
Committed: Tue Jan 17 14:44:27 2017 +0100

----------------------------------------------------------------------
 .../rules/FlinkAggregateJoinTransposeRule.java  |  346 +++
 .../calcite/sql2rel/FlinkRelDecorrelator.java   | 2216 ++++++++++++++++++
 .../flink/table/calcite/FlinkPlannerImpl.scala  |    7 +-
 .../flink/table/plan/rules/FlinkRuleSets.scala  |    5 +-
 .../batch/sql/QueryDecorrelationTest.scala      |  218 ++
 5 files changed, 2787 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/e187b5ee/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/rules/FlinkAggregateJoinTransposeRule.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/rules/FlinkAggregateJoinTransposeRule.java b/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/rules/FlinkAggregateJoinTransposeRule.java
new file mode 100644
index 0000000..ac36b3c
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/rules/FlinkAggregateJoinTransposeRule.java
@@ -0,0 +1,346 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.calcite.rules;
+
+import com.google.common.base.Function;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import org.apache.calcite.linq4j.Ord;
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.Join;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.core.RelFactories;
+import org.apache.calcite.rel.logical.LogicalAggregate;
+import org.apache.calcite.rel.logical.LogicalJoin;
+import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.sql.SqlSplittableAggFunction;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.tools.RelBuilderFactory;
+import org.apache.calcite.util.Bug;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.Util;
+import org.apache.calcite.util.mapping.Mapping;
+import org.apache.calcite.util.mapping.Mappings;
+
+import java.util.ArrayList;
+import java.util.BitSet;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.TreeMap;
+import org.apache.flink.util.Preconditions;
+
+/**
+ * Copied from {@link org.apache.calcite.rel.rules.AggregateJoinTransposeRule}, and should be
+ * removed once <a href="https://issues.apache.org/jira/browse/CALCITE-1544">[CALCITE-1544]</a> is fixed.
+ */
+public class FlinkAggregateJoinTransposeRule extends RelOptRule {
+       public static final FlinkAggregateJoinTransposeRule INSTANCE = new FlinkAggregateJoinTransposeRule(LogicalAggregate.class, LogicalJoin.class, RelFactories.LOGICAL_BUILDER, false);
+
+       /**
+        * Extended instance of the rule that can push down aggregate functions.
+        */
+       public static final FlinkAggregateJoinTransposeRule EXTENDED = new FlinkAggregateJoinTransposeRule(LogicalAggregate.class, LogicalJoin.class, RelFactories.LOGICAL_BUILDER, true);
+
+       private final boolean allowFunctions;
+
+       /**
+        * Creates a FlinkAggregateJoinTransposeRule.
+        */
+       public FlinkAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, Class<? extends Join> joinClass, RelBuilderFactory relBuilderFactory, boolean allowFunctions) {
+               super(operand(aggregateClass, null, Aggregate.IS_SIMPLE, operand(joinClass, any())), relBuilderFactory, null);
+               this.allowFunctions = allowFunctions;
+       }
+
+       @Deprecated // to be removed before 2.0
+       public FlinkAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Join> joinClass, RelFactories.JoinFactory joinFactory) {
+               this(aggregateClass, joinClass, RelBuilder.proto(aggregateFactory, joinFactory), false);
+       }
+
+       @Deprecated // to be removed before 2.0
+       public FlinkAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Join> joinClass, RelFactories.JoinFactory joinFactory, boolean allowFunctions) {
+               this(aggregateClass, joinClass, RelBuilder.proto(aggregateFactory, joinFactory), allowFunctions);
+       }
+
+       @Deprecated // to be removed before 2.0
+       public FlinkAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Join> joinClass, RelFactories.JoinFactory joinFactory, RelFactories.ProjectFactory projectFactory) {
+               this(aggregateClass, joinClass, RelBuilder.proto(aggregateFactory, joinFactory, projectFactory), false);
+       }
+
+       @Deprecated // to be removed before 2.0
+       public FlinkAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Join> joinClass, RelFactories.JoinFactory joinFactory, RelFactories.ProjectFactory projectFactory, boolean allowFunctions) {
+               this(aggregateClass, joinClass, RelBuilder.proto(aggregateFactory, joinFactory, projectFactory), allowFunctions);
+       }
+
+       @Override
+       public void onMatch(RelOptRuleCall call) {
+               final Aggregate aggregate = call.rel(0);
+               final Join join = call.rel(1);
+               final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
+               final RelBuilder relBuilder = call.builder();
+
+               // If any aggregate functions do not support splitting, bail out
+               // If any aggregate call has a filter, bail out
+               for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
+                       if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null) {
+                               return;
+                       }
+                       if (aggregateCall.filterArg >= 0) {
+                               return;
+                       }
+               }
+
+               // If it is not an inner join, we do not push the
+               // aggregate operator
+               if (join.getJoinType() != JoinRelType.INNER) {
+                       return;
+               }
+
+               if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
+                       return;
+               }
+
+               // Do the columns used by the join appear in the output of the aggregate?
+               final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
+               final RelMetadataQuery mq = RelMetadataQuery.instance();
+               final ImmutableBitSet keyColumns = keyColumns(aggregateColumns, mq.getPulledUpPredicates(join).pulledUpPredicates);
+               final ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
+               final boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
+               final ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);
+
+               // Split join condition
+               final List<Integer> leftKeys = Lists.newArrayList();
+               final List<Integer> rightKeys = Lists.newArrayList();
+               final List<Boolean> filterNulls = Lists.newArrayList();
+               RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(), join.getCondition(), leftKeys, rightKeys, filterNulls);
+               // If it contains non-equi join conditions, we bail out
+               if (!nonEquiConj.isAlwaysTrue()) {
+                       return;
+               }
+
+               // Push each aggregate function down to each side that contains all of its
+               // arguments. Note that COUNT(*), because it has no arguments, can go to
+               // both sides.
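+               // Illustrative note (not in the original Calcite comment): for
+               // SELECT d.deptno, SUM(e.sal) over an inner join of e and d,
+               // SUM(e.sal) can only be pushed to the input that provides e.sal,
+               // while COUNT(*) could be computed on either input.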
+               final Map<Integer, Integer> map = new HashMap<>();
+               final List<Side> sides = new ArrayList<>();
+               int uniqueCount = 0;
+               int offset = 0;
+               int belowOffset = 0;
+               for (int s = 0; s < 2; s++) {
+                       final Side side = new Side();
+                       final RelNode joinInput = join.getInput(s);
+                       int fieldCount = joinInput.getRowType().getFieldCount();
+                       final ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
+                       final ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
+                       for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
+                               map.put(c.e, belowOffset + c.i);
+                       }
+                       final ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
+                       final boolean unique;
+                       if (!allowFunctions) {
+                               assert aggregate.getAggCallList().isEmpty();
+                               // If there are no functions, it doesn't matter as much whether we
+                               // aggregate the inputs before the join, because there will not be
+                               // any functions experiencing a cartesian product effect.
+                               //
+                               // But finding out whether the input is already unique requires a call
+                               // to areColumnsUnique that currently (until [CALCITE-1048] "Make
+                               // metadata more robust" is fixed) places a heavy load on
+                               // the metadata system.
+                               //
+                               // So we choose to imagine the input is already unique, which is
+                               // untrue but harmless.
+                               //
+                               Util.discard(Bug.CALCITE_1048_FIXED);
+                               unique = true;
+                       } else {
+                               final Boolean unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey);
+                               unique = unique0 != null && unique0;
+                       }
+                       if (unique) {
+                               ++uniqueCount;
+                               side.aggregate = false;
+                               side.newInput = joinInput;
+                       } else {
+                               side.aggregate = true;
+                               List<AggregateCall> belowAggCalls = new ArrayList<>();
+                               final SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry = registry(belowAggCalls);
+                               final Mappings.TargetMapping mapping = s == 0 ? Mappings.createIdentity(fieldCount) : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
+                               for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
+                                       final SqlAggFunction aggregation = aggCall.e.getAggregation();
+                                       final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
+                                       final AggregateCall call1;
+                                       if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
+                                               call1 = splitter.split(aggCall.e, mapping);
+                                       } else {
+                                               call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e);
+                                       }
+                                       if (call1 != null) {
+                                               side.split.put(aggCall.i, belowAggregateKey.cardinality() + belowAggCallRegistry.register(call1));
+                                       }
+                               }
+                               side.newInput = relBuilder.push(joinInput).aggregate(relBuilder.groupKey(belowAggregateKey, false, null), belowAggCalls).build();
+                       }
+                       offset += fieldCount;
+                       belowOffset += side.newInput.getRowType().getFieldCount();
+                       sides.add(side);
+               }
+
+               if (uniqueCount == 2) {
+                       // Both inputs to the join are unique. There is nothing to be gained by
+                       // this rule. In fact, this aggregate+join may be the result of a previous
+                       // invocation of this rule; if we continue we might loop forever.
+                       return;
+               }
+
+               // Update condition
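+               // (Descriptive note added for readability; not in the original source.)
+               // `map` records where each join-input field lands after the per-side
+               // aggregation, so the join condition is rewritten to the new indexes.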
+               final Mapping mapping = (Mapping) Mappings.target(new Function<Integer, Integer>() {
+                       public Integer apply(Integer a0) {
+                               return map.get(a0);
+                       }
+               }, join.getRowType().getFieldCount(), belowOffset);
+               final RexNode newCondition = RexUtil.apply(mapping, join.getCondition());
+
+               // Create new join
+               relBuilder.push(sides.get(0).newInput).push(sides.get(1).newInput).join(join.getJoinType(), newCondition);
+
+               // Aggregate above to sum up the sub-totals
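+               // (Note added for readability, not in the original source: each function's
+               // SqlSplittableAggFunction#topSplit combines the partial results, e.g.
+               // partial COUNTs from the two inputs are multiplied across the inner join.)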
+               final List<AggregateCall> newAggCalls = new ArrayList<>();
+               final int groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
+               final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
+               final List<RexNode> projects = new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
+               for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
+                       final SqlAggFunction aggregation = aggCall.e.getAggregation();
+                       final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
+                       final Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
+                       final Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
+                       newAggCalls.add(splitter.topSplit(rexBuilder, registry(projects), groupIndicatorCount, relBuilder.peek().getRowType(), aggCall.e, leftSubTotal == null ? -1 : leftSubTotal, rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
+               }
+
+               relBuilder.project(projects);
+
+               boolean aggConvertedToProjects = false;
+               if (allColumnsInAggregate) {
+                       // let's see if we can convert aggregate into projects
+                       List<RexNode> projects2 = new ArrayList<>();
+                       for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
+                               projects2.add(relBuilder.field(key));
+                       }
+                       for (AggregateCall newAggCall : newAggCalls) {
+                               final SqlSplittableAggFunction splitter = newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
+                               if (splitter != null) {
+                                       projects2.add(splitter.singleton(rexBuilder, relBuilder.peek().getRowType(), newAggCall));
+                               }
+                       }
+                       if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
+                               // We successfully converted agg calls into projects.
+                               relBuilder.project(projects2);
+                               aggConvertedToProjects = true;
+                       }
+               }
+
+               if (!aggConvertedToProjects) {
+                       relBuilder.aggregate(relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()), aggregate.indicator, Mappings.apply2(mapping, aggregate.getGroupSets())), newAggCalls);
+               }
+
+               call.transformTo(relBuilder.build());
+       }
+
+       /**
+        * Computes the closure of a set of columns according to a given list of
+        * constraints. Each 'x = y' constraint causes bit y to be set if bit x is
+        * set, and vice versa.
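+        *
+        * <p>For example (added note, not in the original Javadoc): with
+        * aggregateColumns = {0} and a single predicate $0 = $3, the returned
+        * set is {0, 3}.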
+        */
+       private static ImmutableBitSet keyColumns(ImmutableBitSet aggregateColumns, ImmutableList<RexNode> predicates) {
+               SortedMap<Integer, BitSet> equivalence = new TreeMap<>();
+               for (RexNode pred : predicates) {
+                       populateEquivalences(equivalence, pred);
+               }
+               ImmutableBitSet keyColumns = aggregateColumns;
+               for (Integer aggregateColumn : aggregateColumns) {
+                       final BitSet bitSet = equivalence.get(aggregateColumn);
+                       if (bitSet != null) {
+                               keyColumns = keyColumns.union(bitSet);
+                       }
+               }
+               return keyColumns;
+       }
+
+       private static void populateEquivalences(Map<Integer, BitSet> equivalence, RexNode predicate) {
+               switch (predicate.getKind()) {
+                       case EQUALS:
+                               RexCall call = (RexCall) predicate;
+                               final List<RexNode> operands = call.getOperands();
+                               if (operands.get(0) instanceof RexInputRef) {
+                                       final RexInputRef ref0 = (RexInputRef) operands.get(0);
+                                       if (operands.get(1) instanceof RexInputRef) {
+                                               final RexInputRef ref1 = (RexInputRef) operands.get(1);
+                                               populateEquivalence(equivalence, ref0.getIndex(), ref1.getIndex());
+                                               populateEquivalence(equivalence, ref1.getIndex(), ref0.getIndex());
+                                       }
+                               }
+               }
+       }
+
+       private static void populateEquivalence(Map<Integer, BitSet> equivalence, int i0, int i1) {
+               BitSet bitSet = equivalence.get(i0);
+               if (bitSet == null) {
+                       bitSet = new BitSet();
+                       equivalence.put(i0, bitSet);
+               }
+               bitSet.set(i1);
+       }
+
+       /**
+        * Creates a {@link SqlSplittableAggFunction.Registry}
+        * that is a view of a list.
+        */
+       private static <E> SqlSplittableAggFunction.Registry<E> registry(final List<E> list) {
+               return new SqlSplittableAggFunction.Registry<E>() {
+                       public int register(E e) {
+                               int i = list.indexOf(e);
+                               if (i < 0) {
+                                       i = list.size();
+                                       list.add(e);
+                               }
+                               return i;
+                       }
+               };
+       }
+
+       /**
+        * Work space for an input to a join.
+        */
+       private static class Side {
+               final Map<Integer, Integer> split = new HashMap<>();
+               RelNode newInput;
+               boolean aggregate;
+       }
+}
+
+// End FlinkAggregateJoinTransposeRule.java
