AMashenkov commented on code in PR #5026:
URL: https://github.com/apache/ignite-3/pull/5026#discussion_r1914697447


##########
modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rule/logical/IgniteMultiJoinOptimizeBushyRule.java:
##########
@@ -0,0 +1,425 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.internal.sql.engine.rule.logical;
+
+import static java.lang.Integer.bitCount;
+
+import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
+import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
+import java.util.ArrayList;
+import java.util.BitSet;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.IdentityHashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.calcite.rel.rules.LoptMultiJoin;
+import org.apache.calcite.rel.rules.MultiJoin;
+import org.apache.calcite.rel.rules.TransformationRule;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexPermuteInputsShuttle;
+import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.util.mapping.Mappings;
+import org.apache.calcite.util.mapping.Mappings.TargetMapping;
+import org.immutables.value.Value;
+import org.jetbrains.annotations.Nullable;
+
+/**
+ * Transformation rule used for optimizing multi-join queries using a bushy join tree strategy.
+ *
+ * <p>This is an implementation of a subset-driven enumeration algorithm. The main loop (which actually consists
+ * of two for-loops) enumerates all subsets of relations in a way suitable for dynamic programming: it guarantees
+ * that for every set {@code S}, any split of {@code S} will produce subsets which have already been processed:
+ * <pre>
+ *     For example, for a join of 4 relations it will produce the following sequence:
+ *         0011
+ *         0101
+ *         0110
+ *         ...
+ *         1101
+ *         1110
+ *         1111
+ * </pre>
+ *
+ * <p>The inner while-loop enumerates all possible splits of a given subset {@code S} into disjoint subsets
+ * {@code lhs} and {@code rhs} such that {@code lhs ∪ rhs = S}.
+ *
+ * <p>Finally, if the initial set of relations is not connected, the algorithm composes cartesian joins
+ * from the best plans until all relations are joined.
+ *
+ * <p>Current limitations are as follows:<ol>
+ *     <li>Only INNER joins are supported.</li>
+ *     <li>The number of relations to optimize is limited to 20. This is due to the time and memory complexity of the chosen algorithm.</li>
+ *     <li>Disjunctive predicates are not considered as connections.</li>
+ * </ol>
+ */
+@Value.Enclosing
+public class IgniteMultiJoinOptimizeBushyRule
+        extends RelRule<IgniteMultiJoinOptimizeBushyRule.Config>
+        implements TransformationRule {
+
+    private static final int MAX_JOIN_SIZE = 20;
+
+    /**
+     * Comparator that puts better vertexes first.
+     *
+     * <p>A better vertex is one that incorporates more relations or, for the same number of relations, costs less.
+     */
+    private static final Comparator<Vertex> VERTEX_COMPARATOR = 
+            Comparator.<Vertex>comparingInt(v -> bitCount(v.id))
+                    .reversed()
+                    .thenComparingDouble(v -> v.cost);
+
+    /** Creates a MultiJoinOptimizeBushyRule. */
+    private IgniteMultiJoinOptimizeBushyRule(Config config) {
+        super(config);
+    }
+
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+        MultiJoin multiJoinRel = call.rel(0);
+
+        int numberOfRelations = multiJoinRel.getInputs().size();
+        if (numberOfRelations > MAX_JOIN_SIZE) {
+            return;
+        }
+
+        // Currently, algorithm below can handle only INNER JOINs
+        if (multiJoinRel.isFullOuterJoin()) {
+            return;
+        }
+
+        for (JoinRelType joinType : multiJoinRel.getJoinTypes()) {
+            if (joinType != JoinRelType.INNER) {
+                return;
+            }
+        }
+
+        LoptMultiJoin multiJoin = new LoptMultiJoin(multiJoinRel);
+
+        RexBuilder rexBuilder = multiJoinRel.getCluster().getRexBuilder();
+        RelBuilder relBuilder = call.builder();
+        RelMetadataQuery mq = call.getMetadataQuery();
+
+        Set<RexNode> unusedConditions = Collections.newSetFromMap(new 
IdentityHashMap<>());
+        unusedConditions.addAll(multiJoin.getJoinFilters());
+
+        Int2ObjectMap<List<Edge>> edges = collectEdges(multiJoin, 
unusedConditions);
+        Int2ObjectMap<Vertex> bestPlan = new Int2ObjectOpenHashMap<>();
+        BitSet connections = new BitSet(1 << numberOfRelations);
+
+        int id = 0b1;
+        int fieldOffset = 0;
+        for (RelNode input : multiJoinRel.getInputs()) {
+            TargetMapping mapping = Mappings.offsetSource(
+                    
Mappings.createIdentity(input.getRowType().getFieldCount()),
+                    fieldOffset,
+                    multiJoin.getNumTotalFields()
+            );
+
+            bestPlan.put(id, new Vertex(id, mq.getRowCount(input), input, 
mapping));
+            connections.set(id);
+
+            id <<= 1;
+            fieldOffset += input.getRowType().getFieldCount();
+        }
+
+        Vertex bestSoFar = null;
+        for (int k = 2; k <= numberOfRelations; k++) {
+            for (int s = (1 << (k - 1)) + 1; s < 1 << k; s++) {
+                int lhs = Integer.lowestOneBit(s);
+                while (lhs < (s / 2) + 1) {
+                    int rhs = s - lhs;
+
+                    List<Edge> edges0;
+                    if (connections.get(lhs) && connections.get(rhs)) {
+                        edges0 = findEdges(lhs, rhs, edges);
+                    } else {
+                        edges0 = List.of();
+                    }
+
+                    if (!edges0.isEmpty()) {
+                        connections.set(s);
+
+                        Vertex planLhs = bestPlan.get(lhs);
+                        Vertex planRhs = bestPlan.get(rhs);
+
+                        Vertex newPlan = createJoin(planLhs, planRhs, edges0, 
mq, relBuilder, rexBuilder);
+                        Vertex currentBest = bestPlan.get(s);
+                        if (currentBest == null || currentBest.cost > 
newPlan.cost) {
+                            bestPlan.put(s, newPlan);
+
+                            bestSoFar = chooseBest(bestSoFar, newPlan);
+                        }
+
+                        aggregateEdges(edges, lhs, rhs);
+                    }
+
+                    lhs = s & (lhs - s);
+                }
+            }
+        }
+
+        Vertex best;
+        if (bestSoFar == null || bestSoFar.id != (1 << numberOfRelations) - 1) 
{
+            best = composeCartesianJoin(bestPlan, edges, bestSoFar, mq, 
relBuilder, rexBuilder);
+        } else {
+            best = bestSoFar;
+        }
+
+        RelNode result = relBuilder
+                .push(best.rel)
+                .filter(RexUtil.composeConjunction(rexBuilder, 
unusedConditions)
+                        .accept(new RexPermuteInputsShuttle(best.mapping, 
best.rel)))
+                .project(relBuilder.fields(best.mapping))
+                .build();
+
+        call.transformTo(result);
+    }
+
+    private static void aggregateEdges(Int2ObjectMap<List<Edge>> edges, int 
lhs, int rhs) {
+        int id = lhs | rhs;
+        if (!edges.containsKey(id)) {
+            Set<Edge> union = Collections.newSetFromMap(new 
IdentityHashMap<>());
+
+            union.addAll(edges.getOrDefault(lhs, List.of()));
+            union.addAll(edges.getOrDefault(rhs, List.of()));
+
+            edges.put(id, List.copyOf(union));
+        }
+    }
+
+    private static Vertex composeCartesianJoin(
+            Int2ObjectMap<Vertex> bestPlan,
+            Int2ObjectMap<List<Edge>> edges,
+            @Nullable Vertex bestSoFar,
+            RelMetadataQuery mq,
+            RelBuilder relBuilder,
+            RexBuilder rexBuilder
+    ) {
+        List<Vertex> options;
+
+        if (bestSoFar != null) {
+            options = new ArrayList<>();
+
+            for (Vertex option : bestPlan.values()) {
+                if ((option.id & bestSoFar.id) == 0) {
+                    options.add(option);
+                }
+            }
+        } else {
+            options = new ArrayList<>(bestPlan.values());
+        }
+
+        options.sort(VERTEX_COMPARATOR);
+
+        Iterator<Vertex> it = options.iterator();
+
+        if (bestSoFar == null) {
+            bestSoFar = it.next();
+        }
+
+        while (it.hasNext()) {
+            Vertex input = it.next();
+
+            if ((bestSoFar.id & input.id) != 0) {
+                continue;
+            }
+
+            List<Edge> edges0 = findEdges(bestSoFar.id, input.id, edges);
+
+            aggregateEdges(edges, bestSoFar.id, input.id);
+
+            bestSoFar = createJoin(bestSoFar, input, edges0, mq, relBuilder, 
rexBuilder);
+        }
+
+        return bestSoFar;
+    }
+
+    private static Vertex chooseBest(@Nullable Vertex currentBest, Vertex 
candidate) {
+        if (currentBest == null) {
+            return candidate;
+        }
+
+        if (VERTEX_COMPARATOR.compare(currentBest, candidate) > 0) {
+            return candidate;
+        }
+
+        return currentBest;
+    }
+
+    private static Int2ObjectMap<List<Edge>> collectEdges(LoptMultiJoin 
multiJoin, Set<RexNode> unusedConditions) {
+        Int2ObjectMap<List<Edge>> edges = new Int2ObjectOpenHashMap<>();
+
+        for (RexNode condition : multiJoin.getJoinFilters()) {
+            int[] inputRefs = 
multiJoin.getFactorsRefByJoinFilter(condition).toArray();
+
+            if (inputRefs.length < 2 || condition.isA(SqlKind.OR)) {
+                continue;
+            }
+
+            unusedConditions.remove(condition);

Review Comment:
   We can fill `unusedConditions` inside the `if` block instead of removing conditions from it afterwards.
   This avoids one extra iteration over the join filters and the lookups into the hash set, and the `Set` can be replaced with a `List`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to