This is an automated email from the ASF dual-hosted git repository. jark pushed a commit to branch release-1.9 in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.9 by this push: new 47d0fed [FLINK-13545] [table-planner-blink] JoinToMultiJoinRule should not match SEMI/ANTI LogicalJoin 47d0fed is described below commit 47d0fed51999ccea847b3fa6b6645a6adad54843 Author: godfreyhe <godfre...@163.com> AuthorDate: Fri Aug 2 11:50:28 2019 +0800 [FLINK-13545] [table-planner-blink] JoinToMultiJoinRule should not match SEMI/ANTI LogicalJoin This closes #9329 --- .../rules/logical/FlinkJoinToMultiJoinRule.java | 594 +++++++++++++++++++++ .../planner/plan/rules/FlinkBatchRuleSets.scala | 2 +- .../planner/plan/rules/FlinkStreamRuleSets.scala | 2 +- .../rules/logical/FlinkJoinToMultiJoinRuleTest.xml | 81 +++ .../logical/FlinkJoinToMultiJoinRuleTest.scala | 72 +++ 5 files changed, 749 insertions(+), 2 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRule.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRule.java new file mode 100644 index 0000000..0d4a954 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRule.java @@ -0,0 +1,594 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.rules.logical; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.core.RelFactories; +import org.apache.calcite.rel.logical.LogicalJoin; +import org.apache.calcite.rel.rules.MultiJoin; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.rex.RexVisitorImpl; +import org.apache.calcite.tools.RelBuilderFactory; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.ImmutableIntList; +import org.apache.calcite.util.Pair; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * This rule is copied from Calcite's {@link org.apache.calcite.rel.rules.JoinToMultiJoinRule}. + * This file should be removed while upgrading Calcite version to 1.21. [CALCITE-3225] + * Modification: + * - Does not match SEMI/ANTI join. lines changed (142-145) + * - lines changed (440-451) + */ + +/** + * Planner rule to flatten a tree of + * {@link org.apache.calcite.rel.logical.LogicalJoin}s + * into a single {@link MultiJoin} with N inputs. 
+ * + * <p>An input is not flattened if + * the input is a null generating input in an outer join, i.e., either input in + * a full outer join, the right hand side of a left outer join, or the left hand + * side of a right outer join. + * + * <p>Join conditions are also pulled up from the inputs into the topmost + * {@link MultiJoin}, + * unless the input corresponds to a null generating input in an outer join. + * + * <p>Outer join information is also stored in the {@link MultiJoin}. A + * boolean flag indicates if the join is a full outer join, and in the case of + * left and right outer joins, the join type and outer join conditions are + * stored in arrays in the {@link MultiJoin}. This outer join information is + * associated with the null generating input in the outer join. So, in the case + * of a left outer join between A and B, the information is associated with B, + * not A. + * + * <p>Here are examples of the {@link MultiJoin}s constructed after this rule + * has been applied on the following join trees. 
+ * + * <ul> + * <li>A JOIN B → MJ(A, B) + * + * <li>A JOIN B JOIN C → MJ(A, B, C) + * + * <li>A LEFT JOIN B → MJ(A, B), left outer join on input#1 + * + * <li>A RIGHT JOIN B → MJ(A, B), right outer join on input#0 + * + * <li>A FULL JOIN B → MJ[full](A, B) + * + * <li>A LEFT JOIN (B JOIN C) → MJ(A, MJ(B, C))), left outer join on + * input#1 in the outermost MultiJoin + * + * <li>(A JOIN B) LEFT JOIN C → MJ(A, B, C), left outer join on input#2 + * + * <li>(A LEFT JOIN B) JOIN C → MJ(MJ(A, B), C), left outer join on input#1 + * of the inner MultiJoin TODO + * + * <li>A LEFT JOIN (B FULL JOIN C) → MJ(A, MJ[full](B, C)), left outer join + * on input#1 in the outermost MultiJoin + * + * <li>(A LEFT JOIN B) FULL JOIN (C RIGHT JOIN D) → + * MJ[full](MJ(A, B), MJ(C, D)), left outer join on input #1 in the first + * inner MultiJoin and right outer join on input#0 in the second inner + * MultiJoin + * </ul> + * + * <p>The constructor is parameterized to allow any sub-class of + * {@link org.apache.calcite.rel.core.Join}, not just + * {@link org.apache.calcite.rel.logical.LogicalJoin}.</p> + * + * @see org.apache.calcite.rel.rules.FilterMultiJoinMergeRule + * @see org.apache.calcite.rel.rules.ProjectMultiJoinMergeRule + */ +public class FlinkJoinToMultiJoinRule extends RelOptRule { + public static final FlinkJoinToMultiJoinRule INSTANCE = + new FlinkJoinToMultiJoinRule(LogicalJoin.class, RelFactories.LOGICAL_BUILDER); + + //~ Constructors ----------------------------------------------------------- + + @Deprecated // to be removed before 2.0 + public FlinkJoinToMultiJoinRule(Class<? extends Join> clazz) { + this(clazz, RelFactories.LOGICAL_BUILDER); + } + + /** + * Creates a FlinkJoinToMultiJoinRule. + */ + public FlinkJoinToMultiJoinRule(Class<? 
extends Join> clazz, + RelBuilderFactory relBuilderFactory) { + super( + operand(clazz, + operand(RelNode.class, any()), + operand(RelNode.class, any())), + relBuilderFactory, null); + } + + //~ Methods ---------------------------------------------------------------- + + @Override + public boolean matches(RelOptRuleCall call) { + final Join origJoin = call.rel(0); + return origJoin.getJoinType() != JoinRelType.SEMI && origJoin.getJoinType() != JoinRelType.ANTI; + } + + public void onMatch(RelOptRuleCall call) { + final Join origJoin = call.rel(0); + final RelNode left = call.rel(1); + final RelNode right = call.rel(2); + + // combine the children MultiJoin inputs into an array of inputs + // for the new MultiJoin + final List<ImmutableBitSet> projFieldsList = new ArrayList<>(); + final List<int[]> joinFieldRefCountsList = new ArrayList<>(); + final List<RelNode> newInputs = + combineInputs( + origJoin, + left, + right, + projFieldsList, + joinFieldRefCountsList); + + // combine the outer join information from the left and right + // inputs, and include the outer join information from the current + // join, if it's a left/right outer join + final List<Pair<JoinRelType, RexNode>> joinSpecs = new ArrayList<>(); + combineOuterJoins( + origJoin, + newInputs, + left, + right, + joinSpecs); + + // pull up the join filters from the children MultiJoinRels and + // combine them with the join filter associated with this LogicalJoin to + // form the join filter for the new MultiJoin + List<RexNode> newJoinFilters = combineJoinFilters(origJoin, left, right); + + // add on the join field reference counts for the join condition + // associated with this LogicalJoin + final com.google.common.collect.ImmutableMap<Integer, ImmutableIntList> newJoinFieldRefCountsMap = + addOnJoinFieldRefCounts(newInputs, + origJoin.getRowType().getFieldCount(), + origJoin.getCondition(), + joinFieldRefCountsList); + + List<RexNode> newPostJoinFilters = + combinePostJoinFilters(origJoin, left, right); 
+ + final RexBuilder rexBuilder = origJoin.getCluster().getRexBuilder(); + RelNode multiJoin = + new MultiJoin( + origJoin.getCluster(), + newInputs, + RexUtil.composeConjunction(rexBuilder, newJoinFilters), + origJoin.getRowType(), + origJoin.getJoinType() == JoinRelType.FULL, + Pair.right(joinSpecs), + Pair.left(joinSpecs), + projFieldsList, + newJoinFieldRefCountsMap, + RexUtil.composeConjunction(rexBuilder, newPostJoinFilters, true)); + + call.transformTo(multiJoin); + } + + /** + * Combines the inputs into a LogicalJoin into an array of inputs. + * + * @param join original join + * @param left left input into join + * @param right right input into join + * @param projFieldsList returns a list of the new combined projection + * fields + * @param joinFieldRefCountsList returns a list of the new combined join + * field reference counts + * @return combined left and right inputs in an array + */ + private List<RelNode> combineInputs( + Join join, + RelNode left, + RelNode right, + List<ImmutableBitSet> projFieldsList, + List<int[]> joinFieldRefCountsList) { + final List<RelNode> newInputs = new ArrayList<>(); + + // leave the null generating sides of an outer join intact; don't + // pull up those children inputs into the array we're constructing + if (canCombine(left, join.getJoinType().generatesNullsOnLeft())) { + final MultiJoin leftMultiJoin = (MultiJoin) left; + for (int i = 0; i < left.getInputs().size(); i++) { + newInputs.add(leftMultiJoin.getInput(i)); + projFieldsList.add(leftMultiJoin.getProjFields().get(i)); + joinFieldRefCountsList.add( + leftMultiJoin.getJoinFieldRefCountsMap().get(i).toIntArray()); + } + } else { + newInputs.add(left); + projFieldsList.add(null); + joinFieldRefCountsList.add( + new int[left.getRowType().getFieldCount()]); + } + + if (canCombine(right, join.getJoinType().generatesNullsOnRight())) { + final MultiJoin rightMultiJoin = (MultiJoin) right; + for (int i = 0; i < right.getInputs().size(); i++) { + 
newInputs.add(rightMultiJoin.getInput(i)); + projFieldsList.add( + rightMultiJoin.getProjFields().get(i)); + joinFieldRefCountsList.add( + rightMultiJoin.getJoinFieldRefCountsMap().get(i).toIntArray()); + } + } else { + newInputs.add(right); + projFieldsList.add(null); + joinFieldRefCountsList.add( + new int[right.getRowType().getFieldCount()]); + } + + return newInputs; + } + + /** + * Combines the outer join conditions and join types from the left and right + * join inputs. If the join itself is either a left or right outer join, + * then the join condition corresponding to the join is also set in the + * position corresponding to the null-generating input into the join. The + * join type is also set. + * + * @param joinRel join rel + * @param combinedInputs the combined inputs to the join + * @param left left child of the joinrel + * @param right right child of the joinrel + * @param joinSpecs the list where the join types and conditions will be + * copied + */ + private void combineOuterJoins( + Join joinRel, + List<RelNode> combinedInputs, + RelNode left, + RelNode right, + List<Pair<JoinRelType, RexNode>> joinSpecs) { + JoinRelType joinType = joinRel.getJoinType(); + boolean leftCombined = + canCombine(left, joinType.generatesNullsOnLeft()); + boolean rightCombined = + canCombine(right, joinType.generatesNullsOnRight()); + switch (joinType) { + case LEFT: + if (leftCombined) { + copyOuterJoinInfo( + (MultiJoin) left, + joinSpecs, + 0, + null, + null); + } else { + joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); + } + joinSpecs.add(Pair.of(joinType, joinRel.getCondition())); + break; + case RIGHT: + joinSpecs.add(Pair.of(joinType, joinRel.getCondition())); + if (rightCombined) { + copyOuterJoinInfo( + (MultiJoin) right, + joinSpecs, + left.getRowType().getFieldCount(), + right.getRowType().getFieldList(), + joinRel.getRowType().getFieldList()); + } else { + joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); + } + break; + default: + if 
(leftCombined) { + copyOuterJoinInfo( + (MultiJoin) left, + joinSpecs, + 0, + null, + null); + } else { + joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); + } + if (rightCombined) { + copyOuterJoinInfo( + (MultiJoin) right, + joinSpecs, + left.getRowType().getFieldCount(), + right.getRowType().getFieldList(), + joinRel.getRowType().getFieldList()); + } else { + joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); + } + } + } + + /** + * Copies outer join data from a source MultiJoin to a new set of arrays. + * Also adjusts the conditions to reflect the new position of an input if + * that input ends up being shifted to the right. + * + * @param multiJoin the source MultiJoin + * @param destJoinSpecs the list where the join types and conditions will + * be copied + * @param adjustmentAmount if > 0, the amount the RexInputRefs in the join + * conditions need to be adjusted by + * @param srcFields the source fields that the original join conditions + * are referencing + * @param destFields the destination fields that the new join conditions + */ + private void copyOuterJoinInfo( + MultiJoin multiJoin, + List<Pair<JoinRelType, RexNode>> destJoinSpecs, + int adjustmentAmount, + List<RelDataTypeField> srcFields, + List<RelDataTypeField> destFields) { + final List<Pair<JoinRelType, RexNode>> srcJoinSpecs = + Pair.zip( + multiJoin.getJoinTypes(), + multiJoin.getOuterJoinConditions()); + + if (adjustmentAmount == 0) { + destJoinSpecs.addAll(srcJoinSpecs); + } else { + assert srcFields != null; + assert destFields != null; + int nFields = srcFields.size(); + int[] adjustments = new int[nFields]; + for (int idx = 0; idx < nFields; idx++) { + adjustments[idx] = adjustmentAmount; + } + for (Pair<JoinRelType, RexNode> src + : srcJoinSpecs) { + destJoinSpecs.add( + Pair.of( + src.left, + src.right == null + ? 
null + : src.right.accept( + new RelOptUtil.RexInputConverter( + multiJoin.getCluster().getRexBuilder(), + srcFields, destFields, adjustments)))); + } + } + } + + /** + * Combines the join filters from the left and right inputs (if they are + * MultiJoinRels) with the join filter in the joinrel into a single AND'd + * join filter, unless the inputs correspond to null generating inputs in an + * outer join. + * + * @param joinRel join rel + * @param left left child of the join + * @param right right child of the join + * @return combined join filters AND-ed together + */ + private List<RexNode> combineJoinFilters( + Join joinRel, + RelNode left, + RelNode right) { + JoinRelType joinType = joinRel.getJoinType(); + + // AND the join condition if this isn't a left or right outer join; + // in those cases, the outer join condition is already tracked + // separately + final List<RexNode> filters = new ArrayList<>(); + if ((joinType != JoinRelType.LEFT) && (joinType != JoinRelType.RIGHT)) { + filters.add(joinRel.getCondition()); + } + if (canCombine(left, joinType.generatesNullsOnLeft())) { + filters.add(((MultiJoin) left).getJoinFilter()); + } + // Need to adjust the RexInputs of the right child, since + // those need to shift over to the right + if (canCombine(right, joinType.generatesNullsOnRight())) { + MultiJoin multiJoin = (MultiJoin) right; + filters.add( + shiftRightFilter(joinRel, left, multiJoin, + multiJoin.getJoinFilter())); + } + + return filters; + } + + /** + * Returns whether an input can be merged into a given relational expression + * without changing semantics. 
+ * + * @param input input into a join + * @param nullGenerating true if the input is null generating + * @return true if the input can be combined into a parent MultiJoin + */ + private boolean canCombine(RelNode input, boolean nullGenerating) { + return input instanceof MultiJoin + && !((MultiJoin) input).isFullOuterJoin() + && !(containsOuter((MultiJoin) input)) + && !nullGenerating; + } + + private boolean containsOuter(MultiJoin multiJoin) { + for (JoinRelType joinType : multiJoin.getJoinTypes()) { + if (joinType.isOuterJoin()) { + return true; + } + } + return false; + } + + /** + * Shifts a filter originating from the right child of the LogicalJoin to the + * right, to reflect the filter now being applied on the resulting + * MultiJoin. + * + * @param joinRel the original LogicalJoin + * @param left the left child of the LogicalJoin + * @param right the right child of the LogicalJoin + * @param rightFilter the filter originating from the right child + * @return the adjusted right filter + */ + private RexNode shiftRightFilter( + Join joinRel, + RelNode left, + MultiJoin right, + RexNode rightFilter) { + if (rightFilter == null) { + return null; + } + + int nFieldsOnLeft = left.getRowType().getFieldList().size(); + int nFieldsOnRight = right.getRowType().getFieldList().size(); + int[] adjustments = new int[nFieldsOnRight]; + for (int i = 0; i < nFieldsOnRight; i++) { + adjustments[i] = nFieldsOnLeft; + } + rightFilter = + rightFilter.accept( + new RelOptUtil.RexInputConverter( + joinRel.getCluster().getRexBuilder(), + right.getRowType().getFieldList(), + joinRel.getRowType().getFieldList(), + adjustments)); + return rightFilter; + } + + /** + * Adds on to the existing join condition reference counts the references + * from the new join condition. 
+ * + * @param multiJoinInputs inputs into the new MultiJoin + * @param nTotalFields total number of fields in the MultiJoin + * @param joinCondition the new join condition + * @param origJoinFieldRefCounts existing join condition reference counts + * + * @return Map containing the new join condition + */ + private com.google.common.collect.ImmutableMap<Integer, ImmutableIntList> addOnJoinFieldRefCounts( + List<RelNode> multiJoinInputs, + int nTotalFields, + RexNode joinCondition, + List<int[]> origJoinFieldRefCounts) { + // count the input references in the join condition + int[] joinCondRefCounts = new int[nTotalFields]; + joinCondition.accept(new FlinkJoinToMultiJoinRule.InputReferenceCounter(joinCondRefCounts)); + + // first, make a copy of the ref counters + final Map<Integer, int[]> refCountsMap = new HashMap<>(); + int nInputs = multiJoinInputs.size(); + int currInput = 0; + for (int[] origRefCounts : origJoinFieldRefCounts) { + refCountsMap.put( + currInput, + origRefCounts.clone()); + currInput++; + } + + // add on to the counts for each input into the MultiJoin the + // reference counts computed for the current join condition + currInput = -1; + int startField = 0; + int nFields = 0; + for (int i = 0; i < nTotalFields; i++) { + if (joinCondRefCounts[i] == 0) { + continue; + } + while (i >= (startField + nFields)) { + startField += nFields; + currInput++; + assert currInput < nInputs; + nFields = + multiJoinInputs.get(currInput).getRowType().getFieldCount(); + } + int[] refCounts = refCountsMap.get(currInput); + refCounts[i - startField] += joinCondRefCounts[i]; + } + + final com.google.common.collect.ImmutableMap.Builder<Integer, ImmutableIntList> builder = + com.google.common.collect.ImmutableMap.builder(); + for (Map.Entry<Integer, int[]> entry : refCountsMap.entrySet()) { + builder.put(entry.getKey(), ImmutableIntList.of(entry.getValue())); + } + return builder.build(); + } + + /** + * Combines the post-join filters from the left and right inputs (if 
they + * are MultiJoinRels) into a single AND'd filter. + * + * @param joinRel the original LogicalJoin + * @param left left child of the LogicalJoin + * @param right right child of the LogicalJoin + * @return combined post-join filters AND'd together + */ + private List<RexNode> combinePostJoinFilters( + Join joinRel, + RelNode left, + RelNode right) { + final List<RexNode> filters = new ArrayList<>(); + if (right instanceof MultiJoin) { + final MultiJoin multiRight = (MultiJoin) right; + filters.add( + shiftRightFilter(joinRel, left, multiRight, + multiRight.getPostJoinFilter())); + } + + if (left instanceof MultiJoin) { + filters.add(((MultiJoin) left).getPostJoinFilter()); + } + + return filters; + } + + //~ Inner Classes ---------------------------------------------------------- + + /** + * Visitor that keeps a reference count of the inputs used by an expression. + */ + private class InputReferenceCounter extends RexVisitorImpl<Void> { + private final int[] refCounts; + + InputReferenceCounter(int[] refCounts) { + super(true); + this.refCounts = refCounts; + } + + public Void visitInputRef(RexInputRef inputRef) { + refCounts[inputRef.getIndex()]++; + return null; + } + } +} + +// End FlinkJoinToMultiJoinRule.java diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala index 0f48950..076242e 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala @@ -223,7 +223,7 @@ object FlinkBatchRuleSets { val JOIN_REORDER_PERPARE_RULES: RuleSet = RuleSets.ofList( // merge join to MultiJoin - JoinToMultiJoinRule.INSTANCE, + FlinkJoinToMultiJoinRule.INSTANCE, // merge project to 
MultiJoin ProjectMultiJoinMergeRule.INSTANCE, // merge filter to MultiJoin diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala index cb371ab..ce6b6e5 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala @@ -210,7 +210,7 @@ object FlinkStreamRuleSets { // merge filter to MultiJoin FilterMultiJoinMergeRule.INSTANCE, // merge join to MultiJoin - JoinToMultiJoinRule.INSTANCE + FlinkJoinToMultiJoinRule.INSTANCE ) val JOIN_REORDER_RULES: RuleSet = RuleSets.ofList( diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRuleTest.xml new file mode 100644 index 0000000..92a0a62 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRuleTest.xml @@ -0,0 +1,81 @@ +<?xml version="1.0" ?> +<!-- +Licensed to the Apache Software Foundation (ASF) under one or more +contributor license agreements. See the NOTICE file distributed with +this work for additional information regarding copyright ownership. +The ASF licenses this file to you under the Apache License, Version 2.0 +(the "License"); you may not use this file except in compliance with +the License. 
You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--> +<Root> + <TestCase name="testDoesNotMatchAntiJoin"> + <Resource name="sql"> + <![CDATA[ +SELECT * FROM (SELECT * FROM T1 JOIN T2 ON a = c) t +WHERE NOT EXISTS (SELECT e FROM T3 WHERE a = e) + ]]> + </Resource> + <Resource name="planBefore"> + <![CDATA[ +LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3]) ++- LogicalFilter(condition=[NOT(EXISTS({ +LogicalFilter(condition=[=($cor0.a, $0)]) + LogicalTableScan(table=[[default_catalog, default_database, T3, source: [TestTableSource(e, f)]]]) +}))], variablesSet=[[$cor0]]) + +- LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3]) + +- LogicalJoin(condition=[=($0, $2)], joinType=[inner]) + :- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b)]]]) + +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(c, d)]]]) +]]> + </Resource> + <Resource name="planAfter"> + <![CDATA[ +LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3]) ++- LogicalJoin(condition=[=($0, $4)], joinType=[anti]) + :- MultiJoin(joinFilter=[=($0, $2)], isFullOuterJoin=[false], joinTypes=[[INNER, INNER]], outerJoinConditions=[[NULL, NULL]], projFields=[[{0, 1}, {0, 1}]]) + : :- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b)]]]) + : +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(c, d)]]]) + +- LogicalProject(e=[$0]) + +- LogicalFilter(condition=[true]) + +- LogicalTableScan(table=[[default_catalog, default_database, T3, source: [TestTableSource(e, f)]]]) +]]> + </Resource> + </TestCase> + <TestCase 
name="testDoesNotMatchSemiJoin"> + <Resource name="sql"> + <![CDATA[SELECT * FROM (SELECT * FROM T1 JOIN T2 ON a = c) t WHERE a IN (SELECT e FROM T3)]]> + </Resource> + <Resource name="planBefore"> + <![CDATA[ +LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3]) ++- LogicalFilter(condition=[IN($0, { +LogicalProject(e=[$0]) + LogicalTableScan(table=[[default_catalog, default_database, T3, source: [TestTableSource(e, f)]]]) +})]) + +- LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3]) + +- LogicalJoin(condition=[=($0, $2)], joinType=[inner]) + :- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b)]]]) + +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(c, d)]]]) +]]> + </Resource> + <Resource name="planAfter"> + <![CDATA[ +LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3]) ++- LogicalJoin(condition=[=($0, $4)], joinType=[semi]) + :- MultiJoin(joinFilter=[=($0, $2)], isFullOuterJoin=[false], joinTypes=[[INNER, INNER]], outerJoinConditions=[[NULL, NULL]], projFields=[[{0, 1}, {0, 1}]]) + : :- LogicalTableScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a, b)]]]) + : +- LogicalTableScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(c, d)]]]) + +- LogicalProject(e=[$0]) + +- LogicalTableScan(table=[[default_catalog, default_database, T3, source: [TestTableSource(e, f)]]]) +]]> + </Resource> + </TestCase> +</Root> diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRuleTest.scala new file mode 100644 index 0000000..4364704 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/FlinkJoinToMultiJoinRuleTest.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.planner.plan.rules.logical + +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.planner.plan.optimize.program.{FlinkBatchProgram, FlinkHepRuleSetProgramBuilder, HEP_RULES_EXECUTION_TYPE} +import org.apache.flink.table.planner.utils.{TableConfigUtils, TableTestBase} + +import org.apache.calcite.plan.hep.HepMatchOrder +import org.apache.calcite.rel.rules.ProjectMultiJoinMergeRule +import org.apache.calcite.tools.RuleSets +import org.junit.{Before, Test} + +/** + * Test for [[FlinkJoinToMultiJoinRule]]. 
+ */ +class FlinkJoinToMultiJoinRuleTest extends TableTestBase { + private val util = batchTestUtil() + + @Before + def setup(): Unit = { + util.buildBatchProgram(FlinkBatchProgram.DEFAULT_REWRITE) + val calciteConfig = TableConfigUtils.getCalciteConfig(util.tableEnv.getConfig) + calciteConfig.getBatchProgram.get.addLast( + "rules", + FlinkHepRuleSetProgramBuilder.newBuilder + .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_COLLECTION) + .setHepMatchOrder(HepMatchOrder.BOTTOM_UP) + .add(RuleSets.ofList( + FlinkJoinToMultiJoinRule.INSTANCE, + ProjectMultiJoinMergeRule.INSTANCE)) + .build() + ) + + util.addTableSource[(Int, Long)]("T1", 'a, 'b) + util.addTableSource[(Int, Long)]("T2", 'c, 'd) + util.addTableSource[(Int, Long)]("T3", 'e, 'f) + } + + @Test + def testDoesNotMatchSemiJoin(): Unit = { + val sqlQuery = + "SELECT * FROM (SELECT * FROM T1 JOIN T2 ON a = c) t WHERE a IN (SELECT e FROM T3)" + util.verifyPlan(sqlQuery) + } + + @Test + def testDoesNotMatchAntiJoin(): Unit = { + val sqlQuery = + """ + |SELECT * FROM (SELECT * FROM T1 JOIN T2 ON a = c) t + |WHERE NOT EXISTS (SELECT e FROM T3 WHERE a = e) + """.stripMargin + util.verifyPlan(sqlQuery) + } +}