gustavodemorais commented on code in PR #26687: URL: https://github.com/apache/flink/pull/26687#discussion_r2153346386
########## flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/JoinToMultiJoinRule.java: ########## @@ -0,0 +1,683 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.rules.logical; + +import org.apache.flink.table.api.TableException; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinInfo; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.LogicalJoin; +import org.apache.calcite.rel.rules.CoreRules; +import org.apache.calcite.rel.rules.FilterMultiJoinMergeRule; +import org.apache.calcite.rel.rules.MultiJoin; +import org.apache.calcite.rel.rules.ProjectMultiJoinMergeRule; +import org.apache.calcite.rel.rules.TransformationRule; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexUtil; +import 
org.apache.calcite.rex.RexVisitorImpl; +import org.apache.calcite.tools.RelBuilderFactory; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.ImmutableIntList; +import org.apache.calcite.util.Pair; +import org.immutables.value.Value; + +import javax.annotation.Nullable; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Flink Planner rule to flatten a tree of {@link Join}s into a single {@link MultiJoin} with N + * inputs. + * + * <p>This rule is copied and adjusted from {@link org.apache.calcite.rel.rules.JoinToMultiJoinRule} + * and {@link JoinToMultiJoinForReorderRule}. In this rule, we support a broader set of left and + * inner joins by rewriting the canCombine() method. The multi join is not expected to be used for + * reordering and will be turned into a {@link + * org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalMultiJoin}. + * + * <p>Join conditions are also pulled up from the inputs into the topmost {@link MultiJoin}. + * + * <p>Join information is also stored in the {@link MultiJoin}. Join conditions are stored in arrays + * in the {@link MultiJoin}. This outer join information is associated with the null generating + * input in the outer join. So, in the case of a left outer join between A and B, the information is + * associated with B, not A. + * + * <p>Here are examples of the {@link MultiJoin}s constructed after this rule has been applied to the + * following join trees.
+ * + * <ul> + * <li>A JOIN B → MJ(A, B) + * <li>A JOIN B JOIN C → MJ(A, B, C) + * <li>A LEFT JOIN B → MJ(A, B) + * <li>A RIGHT JOIN B → MJ(A, B) + * <li>A FULL JOIN B → MJ[full](A, B) + * <li>A LEFT JOIN (B JOIN C) → MJ(A, B, C) + * <li>(A JOIN B) LEFT JOIN C → MJ(A, B, C) + * <li>(A LEFT JOIN B) JOIN C → MJ(A, B, C) + * <li>(A LEFT JOIN B) LEFT JOIN C → MJ(A, B, C) + * <li>(A RIGHT JOIN B) RIGHT JOIN C → MJ(MJ(A, B), C) + * <li>(A LEFT JOIN B) RIGHT JOIN C → MJ(MJ(A, B), C) + * <li>(A RIGHT JOIN B) LEFT JOIN C → MJ(MJ(A, B), C) + * <li>A LEFT JOIN (B FULL JOIN C) → MJ(A, MJ[full](B, C)) + * <li>(A LEFT JOIN B) FULL JOIN (C RIGHT JOIN D) → MJ[full](MJ(A, B), MJ(C, D)) + * <li>SEMI JOIN and ANTI JOIN are not supported yet. + * </ul> + * + * <p>Note: {@link #matches} only accepts INNER and LEFT joins; RIGHT joins are expected to be + * rewritten to LEFT joins beforehand, so the RIGHT and FULL JOIN examples above are not flattened + * by this rule. + * + * <p>The constructor is parameterized to allow any sub-class of {@link Join}, not just {@link + * LogicalJoin}. + * + * @see FilterMultiJoinMergeRule + * @see ProjectMultiJoinMergeRule + * @see CoreRules#JOIN_TO_MULTI_JOIN + */ +@Value.Enclosing +public class JoinToMultiJoinRule extends RelRule<JoinToMultiJoinRule.Config> + implements TransformationRule { + + public static final JoinToMultiJoinRule INSTANCE = JoinToMultiJoinRule.Config.DEFAULT.toRule(); + + /** Creates a JoinToMultiJoinRule. */ + public JoinToMultiJoinRule(Config config) { + super(config); + } + + @Deprecated // to be removed before 2.0 + public JoinToMultiJoinRule(Class<? extends Join> clazz) { + this(Config.DEFAULT.withOperandFor(clazz)); + } + + @Deprecated // to be removed before 2.0 + public JoinToMultiJoinRule( + Class<? 
extends Join> joinClass, RelBuilderFactory relBuilderFactory) { + this( + Config.DEFAULT + .withRelBuilderFactory(relBuilderFactory) + .as(Config.class) + .withOperandFor(joinClass)); + } + + // ~ Methods ---------------------------------------------------------------- + + @Override + public boolean matches(RelOptRuleCall call) { + final Join origJoin = call.rel(0); + if (origJoin.getJoinType() != JoinRelType.INNER + && origJoin.getJoinType() != JoinRelType.LEFT) { + /* This rule expects only INNER and LEFT joins. Right joins are expected to be + rewritten to left joins by the optimizer with {@link FlinkRightJoinToLeftJoinRule} */ + return false; + } + + return origJoin.getJoinType().projectsRight(); + } + + @Override + public void onMatch(RelOptRuleCall call) { + final Join origJoin = call.rel(0); + final RelNode left = call.rel(1); + final RelNode right = call.rel(2); + + // inputNullGenFieldList records whether the field in originJoin is null generate field. + List<Boolean> inputNullGenFieldList = new ArrayList<>(); + // Build null generate field list. + buildInputNullGenFieldList(left, right, origJoin.getJoinType(), inputNullGenFieldList); + + // Combine the children MultiJoin inputs into an array of inputs for the new MultiJoin. + final List<ImmutableBitSet> projFieldsList = new ArrayList<>(); + final List<int[]> joinFieldRefCountsList = new ArrayList<>(); + final List<RelNode> newInputs = + combineInputs( + origJoin, + left, + right, + projFieldsList, + joinFieldRefCountsList, + inputNullGenFieldList); + + // Combine the join information from the left and right inputs, and include the + // join information from the current join. 
+ final List<Pair<JoinRelType, RexNode>> joinSpecs = new ArrayList<>(); + combineJoinInfo(origJoin, newInputs, left, right, joinSpecs, inputNullGenFieldList); + + // Pull up the join filters from the children MultiJoinRels and combine them with the join + // filter associated with this LogicalJoin to form the join filter for the new MultiJoin. + List<RexNode> newJoinFilters = + combineJoinFilters(origJoin, left, right, inputNullGenFieldList); + + // Add on the join field reference counts for the join condition associated with this + // LogicalJoin. + final com.google.common.collect.ImmutableMap<Integer, ImmutableIntList> + newJoinFieldRefCountsMap = + addOnJoinFieldRefCounts( + newInputs, + origJoin.getRowType().getFieldCount(), + origJoin.getCondition(), + joinFieldRefCountsList); + + List<RexNode> newPostJoinFilters = combinePostJoinFilters(origJoin, left, right); + + final RexBuilder rexBuilder = origJoin.getCluster().getRexBuilder(); + RelNode multiJoin = + new MultiJoin( + origJoin.getCluster(), + newInputs, + RexUtil.composeConjunction(rexBuilder, newJoinFilters), + origJoin.getRowType(), + origJoin.getJoinType() == JoinRelType.FULL, + Pair.right(joinSpecs), + Pair.left(joinSpecs), + projFieldsList, + newJoinFieldRefCountsMap, + RexUtil.composeConjunction(rexBuilder, newPostJoinFilters, true)); + + call.transformTo(multiJoin); + } + + private void buildInputNullGenFieldList( + RelNode left, RelNode right, JoinRelType joinType, List<Boolean> isNullGenFieldList) { + if (joinType == JoinRelType.INNER) { + buildNullGenFieldList(left, isNullGenFieldList); + buildNullGenFieldList(right, isNullGenFieldList); + } else if (joinType == JoinRelType.LEFT) { + // If origin joinType is left means join fields from right side must be null generated + // fields, so we need only judge these join fields in left side and set null generate + // field is true for all right fields. 
+ buildNullGenFieldList(left, isNullGenFieldList); + + for (int i = 0; i < right.getRowType().getFieldCount(); i++) { + isNullGenFieldList.add(true); + } + } else { + // Now, join to multi join rule only support Full outer join, Inner join and Left/Right + // join. + throw new TableException( + "This is a bug. Now, join to multi join rule only support Full outer " + + "join, Inner join and Left/Right join."); + } + } + + private void buildNullGenFieldList(RelNode rel, List<Boolean> isNullGenFieldList) { + MultiJoin multiJoin = rel instanceof MultiJoin ? (MultiJoin) rel : null; + if (multiJoin == null) { + // other operators. + for (int i = 0; i < rel.getRowType().getFieldCount(); i++) { + isNullGenFieldList.add(false); + } + } else { + List<RelNode> inputs = multiJoin.getInputs(); + List<JoinRelType> joinTypes = multiJoin.getJoinTypes(); + for (int i = 0; i < inputs.size() - 1; i++) { + // In list joinTypes, right join node will be added as [RIGHT, INNER], so we need to + // get the joinType from joinTypes in index i. + if (joinTypes.get(i) == JoinRelType.RIGHT) { + buildInputNullGenFieldList( + inputs.get(i), inputs.get(i + 1), joinTypes.get(i), isNullGenFieldList); + } else { + // In list joinTypes, left join node and inner join node will be added as + // [INNER, LEFT] and [INNER, INNER] respectively. so we need to get the joinType + // from joinTypes in index i + 1. + buildInputNullGenFieldList( + inputs.get(i), + inputs.get(i + 1), + joinTypes.get(i + 1), + isNullGenFieldList); + } + } + } + } + + /** + * Combines the inputs into a LogicalJoin into an array of inputs. 
+ * + * @param join original join + * @param left left input into join + * @param right right input into join + * @param projFieldsList returns a list of the new combined projection fields + * @param joinFieldRefCountsList returns a list of the new combined join field reference counts + * @return combined left and right inputs in an array + */ + private List<RelNode> combineInputs( + Join join, + RelNode left, + RelNode right, + List<ImmutableBitSet> projFieldsList, + List<int[]> joinFieldRefCountsList, + List<Boolean> inputNullGenFieldList) { + final List<RelNode> newInputs = new ArrayList<>(); + // Leave the null generating sides of an outer join intact; don't pull up those children + // inputs into the array we're constructing. + JoinInfo joinInfo = join.analyzeCondition(); + ImmutableIntList leftKeys = joinInfo.leftKeys; + ImmutableIntList rightKeys = joinInfo.rightKeys; + + if (canCombine( + left, + leftKeys, + join.getJoinType(), + join.getJoinType().generatesNullsOnLeft(), + true, + inputNullGenFieldList, + 0)) { + final MultiJoin leftMultiJoin = (MultiJoin) left; + for (int i = 0; i < leftMultiJoin.getInputs().size(); i++) { + newInputs.add(leftMultiJoin.getInput(i)); + projFieldsList.add(leftMultiJoin.getProjFields().get(i)); + joinFieldRefCountsList.add( + leftMultiJoin.getJoinFieldRefCountsMap().get(i).toIntArray()); + } + + } else { + newInputs.add(left); + projFieldsList.add(null); + joinFieldRefCountsList.add(new int[left.getRowType().getFieldCount()]); + } + + if (canCombine( + right, + rightKeys, + join.getJoinType(), + join.getJoinType().generatesNullsOnRight(), + false, + inputNullGenFieldList, + left.getRowType().getFieldCount())) { + final MultiJoin rightMultiJoin = (MultiJoin) right; + for (int i = 0; i < rightMultiJoin.getInputs().size(); i++) { + newInputs.add(rightMultiJoin.getInput(i)); + projFieldsList.add(rightMultiJoin.getProjFields().get(i)); + joinFieldRefCountsList.add( + rightMultiJoin.getJoinFieldRefCountsMap().get(i).toIntArray()); 
+ } + } else { + newInputs.add(right); + projFieldsList.add(null); + joinFieldRefCountsList.add(new int[right.getRowType().getFieldCount()]); + } + + return newInputs; + } + + /** + * Combines the join conditions and join types from the left and right join inputs. If the join + * itself is either a left or right outer join, then the join condition corresponding to the + * join is also set in the position corresponding to the null-generating input into the join. + * The join type is also set. + * + * @param joinRel join rel + * @param combinedInputs the combined inputs to the join + * @param left left child of the joinrel + * @param right right child of the joinrel + * @param joinSpecs the list where the join types and conditions will be copied + */ + private void combineJoinInfo( + Join joinRel, + List<RelNode> combinedInputs, + RelNode left, + RelNode right, + List<Pair<JoinRelType, RexNode>> joinSpecs, + List<Boolean> inputNullGenFieldList) { + JoinRelType joinType = joinRel.getJoinType(); + JoinInfo joinInfo = joinRel.analyzeCondition(); + ImmutableIntList leftKeys = joinInfo.leftKeys; + final RexBuilder rexBuilder = joinRel.getCluster().getRexBuilder(); + boolean leftCombined = + canCombine( + left, + leftKeys, + joinType, + joinType.generatesNullsOnLeft(), + true, + inputNullGenFieldList, + 0); + switch (joinType) { + case LEFT: + if (leftCombined) { + copyJoinInfo((MultiJoin) left, joinSpecs, 0, null, null); + } else { + joinSpecs.add(Pair.of(JoinRelType.INNER, rexBuilder.makeLiteral(true))); + } + joinSpecs.add(Pair.of(joinType, joinRel.getCondition())); + break; + case INNER: + if (leftCombined) { + copyJoinInfo((MultiJoin) left, joinSpecs, 0, null, null); + } else { + joinSpecs.add(Pair.of(JoinRelType.INNER, rexBuilder.makeLiteral(true))); + } + + joinSpecs.add(Pair.of(joinType, joinRel.getCondition())); + break; + default: + throw new TableException( + "This is a bug. 
This rule only supports left and inner joins"); + } + } + + /** + * Copies join data from a source MultiJoin to a new set of arrays. Also adjusts the conditions + * to reflect the new position of an input if that input ends up being shifted to the right. + * + * @param multiJoin the source MultiJoin + * @param destJoinSpecs the list where the join types and conditions will be copied + * @param adjustmentAmount if > 0, the amount the RexInputRefs in the join conditions need to + * be adjusted by + * @param srcFields the source fields that the original join conditions are referencing + * @param destFields the destination fields that the new join conditions will reference + */ + private void copyJoinInfo( + MultiJoin multiJoin, + List<Pair<JoinRelType, RexNode>> destJoinSpecs, + int adjustmentAmount, + @Nullable List<RelDataTypeField> srcFields, + @Nullable List<RelDataTypeField> destFields) { + // getOuterJoinConditions is used to hold all join conditions (not only outer ones) here + final List<Pair<JoinRelType, RexNode>> srcJoinSpecs = + Pair.zip(multiJoin.getJoinTypes(), multiJoin.getOuterJoinConditions()); + + if (adjustmentAmount == 0) { + destJoinSpecs.addAll(srcJoinSpecs); + } else { + assert srcFields != null; + assert destFields != null; + int nFields = srcFields.size(); + int[] adjustments = new int[nFields]; + for (int idx = 0; idx < nFields; idx++) { + adjustments[idx] = adjustmentAmount; + } + for (Pair<JoinRelType, RexNode> src : srcJoinSpecs) { + destJoinSpecs.add( + Pair.of( + src.left, + src.right == null + ? null + : src.right.accept( + new RelOptUtil.RexInputConverter( + multiJoin.getCluster().getRexBuilder(), + srcFields, + destFields, + adjustments)))); + } + } + } + + /** + * Combines the join filters from the left and right inputs (if they are MultiJoinRels) with the + * join filter in the joinrel into a single AND'd join filter, unless the inputs correspond to + * null generating inputs in an outer join.
+ * + * @param join Join + * @param left Left input of the join + * @param right Right input of the join + * @return combined join filters AND-ed together + */ + private List<RexNode> combineJoinFilters( + Join join, RelNode left, RelNode right, List<Boolean> inputNullGenFieldList) { + JoinRelType joinType = join.getJoinType(); + JoinInfo joinInfo = join.analyzeCondition(); + ImmutableIntList leftKeys = joinInfo.leftKeys; + + if (joinType == JoinRelType.RIGHT) { + throw new TableException("This is a bug. This rule only supports left and inner joins"); + } + // AND the join condition if this isn't a left join; In those cases, the + // outer join condition is already tracked separately. + final List<RexNode> filters = new ArrayList<>(); + if ((joinType != JoinRelType.LEFT)) { + filters.add(join.getCondition()); + } + if (canCombine( + left, + leftKeys, + joinType, + joinType.generatesNullsOnLeft(), + true, + inputNullGenFieldList, + 0)) { + filters.add(((MultiJoin) left).getJoinFilter()); + } + + return filters; + } + + /** + * Returns whether an input can be merged into a given relational expression without changing + * semantics. + * + * <p>This method should be extended to check for the common join key restriction to support + * multiple multi joins. See <a + * href="https://issues.apache.org/jira/browse/FLINK-37890">FLINK-37890</a>. 
+ * + * @param input input into a join + * @param nullGenerating true if the input is null generating + * @return true if the input can be combined into a parent MultiJoin + */ + private boolean canCombine( + RelNode input, + ImmutableIntList joinKeys, + JoinRelType joinType, + boolean nullGenerating, + boolean isLeft, + List<Boolean> inputNullGenFieldList, + int beginIndex) { + if (input instanceof MultiJoin) { + MultiJoin join = (MultiJoin) input; + if (join.isFullOuterJoin() || nullGenerating) { + return false; + } + + if (joinType == JoinRelType.LEFT) { + if (!isLeft) { + return false; + } else { + for (int joinKey : joinKeys) { + if (inputNullGenFieldList.get(joinKey + beginIndex)) { + return false; + } + } + } + } else if (joinType == JoinRelType.RIGHT) { + if (isLeft) { + return false; + } else { + for (int joinKey : joinKeys) { + if (inputNullGenFieldList.get(joinKey + beginIndex)) { + return false; + } + } + } + } else if (joinType == JoinRelType.INNER) { + for (int joinKey : joinKeys) { + if (inputNullGenFieldList.get(joinKey + beginIndex)) { + return false; + } + } + } else { + return false; + } + return true; + } else { + return false; + } + } + + /** + * Shifts a filter originating from the right child of the LogicalJoin to the right, to reflect + * the filter now being applied on the resulting MultiJoin. 
+ * + * @param joinRel the original LogicalJoin + * @param left the left child of the LogicalJoin + * @param right the right child of the LogicalJoin + * @param rightFilter the filter originating from the right child + * @return the adjusted right filter + */ + private RexNode shiftRightFilter( + Join joinRel, RelNode left, MultiJoin right, RexNode rightFilter) { + if (rightFilter == null) { + return null; + } + + int nFieldsOnLeft = left.getRowType().getFieldList().size(); + int nFieldsOnRight = right.getRowType().getFieldList().size(); + int[] adjustments = new int[nFieldsOnRight]; + for (int i = 0; i < nFieldsOnRight; i++) { + adjustments[i] = nFieldsOnLeft; + } + rightFilter = + rightFilter.accept( + new RelOptUtil.RexInputConverter( + joinRel.getCluster().getRexBuilder(), + right.getRowType().getFieldList(), + joinRel.getRowType().getFieldList(), + adjustments)); + return rightFilter; + } + + /** + * Adds on to the existing join condition reference counts the references from the new join + * condition. + * + * @param multiJoinInputs inputs into the new MultiJoin + * @param nTotalFields total number of fields in the MultiJoin + * @param joinCondition the new join condition + * @param origJoinFieldRefCounts existing join condition reference counts + * @return Map containing the updated join field reference counts + */ + private com.google.common.collect.ImmutableMap<Integer, ImmutableIntList> Review Comment: Now using a native map and calling ImmutableMap.copyOf when constructing the new MultiJoin, since the Calcite class expects an ImmutableMap. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org