mihaibudiu commented on code in PR #4619: URL: https://github.com/apache/calcite/pull/4619#discussion_r2504696529
########## core/src/main/java/org/apache/calcite/sql2rel/TopDownGeneralDecorrelator.java: ########## @@ -0,0 +1,785 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.sql2rel; + +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.Strong; +import org.apache.calcite.plan.hep.HepPlanner; +import org.apache.calcite.plan.hep.HepProgram; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Correlate; +import org.apache.calcite.rel.core.CorrelationId; +import org.apache.calcite.rel.core.Filter; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.SetOp; +import org.apache.calcite.rel.core.Sort; +import org.apache.calcite.rel.rules.CoreRules; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexCorrelVariable; +import org.apache.calcite.rex.RexFieldAccess; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import 
org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.rex.RexWindow; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.fun.SqlCountAggFunction; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql2rel.RelDecorrelator.CorDef; +import org.apache.calcite.sql2rel.RelDecorrelator.Frame; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.Litmus; +import org.apache.calcite.util.ReflectUtil; +import org.apache.calcite.util.ReflectiveVisitor; +import org.apache.calcite.util.mapping.Mappings; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import org.checkerframework.checker.nullness.qual.Nullable; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.NavigableMap; +import java.util.NavigableSet; +import java.util.Set; +import java.util.TreeMap; +import java.util.TreeSet; +import java.util.stream.IntStream; + +import static java.util.Objects.requireNonNull; + +/** + * A top‑down, generic decorrelation algorithm that can handle deep nestings of correlated + * subqueries and that generalizes to complex query constructs. 
More details are in paper: + * <a href="https://dl.gi.de/items/b9df4765-d1b0-4267-a77c-4ce4ab0ee62d"> + * Improving Unnesting of Complex Queries</a> + */ +public class TopDownGeneralDecorrelator implements ReflectiveVisitor { + + private final RelBuilder builder; + + private final NavigableSet<CorDef> corDefs; + + private final Map<RelNode, Boolean> hasCorrelatedExpressions; + + private final Map<RelNode, UnnestInfo> mapRelToUnnestInfo; + + private final boolean hasParent; + + private boolean emptyOutputSensitive; + + private RelNode dedupFreeVarsNode; + + @SuppressWarnings("method.invocation.invalid") + private final ReflectUtil.MethodDispatcher<RelNode> dispatcher = + ReflectUtil.createMethodDispatcher( + RelNode.class, getVisitor(), "unnestInternal", RelNode.class); + + @SuppressWarnings("initialization.fields.uninitialized") + public TopDownGeneralDecorrelator( + RelBuilder builder, + boolean hasParent, + boolean emptyOutputSensitive, + @Nullable Set<CorDef> parentCorDefs, + @Nullable Map<RelNode, Boolean> parentHasCorrelatedExpressions, + @Nullable Map<RelNode, UnnestInfo> parentMapRelToUnnestInfo) { + this.builder = builder; + this.hasParent = hasParent; + this.emptyOutputSensitive = emptyOutputSensitive; + this.corDefs = new TreeSet<>(); + if (parentCorDefs != null) { + this.corDefs.addAll(parentCorDefs); + } + this.hasCorrelatedExpressions = parentHasCorrelatedExpressions == null + ? new HashMap<>() + : parentHasCorrelatedExpressions; + this.mapRelToUnnestInfo = parentMapRelToUnnestInfo == null + ? 
new HashMap<>() + : parentMapRelToUnnestInfo; + } + + private static TopDownGeneralDecorrelator createEmptyDecorrelator(RelBuilder builder) { + return new TopDownGeneralDecorrelator(builder, false, false, null, null, null); + } + + private TopDownGeneralDecorrelator createSubDecorrelator() { + TopDownGeneralDecorrelator subDecorrelator = + new TopDownGeneralDecorrelator( + builder, + true, + emptyOutputSensitive, + corDefs, + hasCorrelatedExpressions, + mapRelToUnnestInfo); + subDecorrelator.dedupFreeVarsNode = this.dedupFreeVarsNode; + return subDecorrelator; + } + + public static RelNode decorrelateQuery(RelNode rel, RelBuilder builder) { + HepProgram program = HepProgram.builder() + .addRuleInstance(CoreRules.FILTER_PROJECT_TRANSPOSE) + .addRuleInstance(CoreRules.FILTER_INTO_JOIN) + .addRuleInstance(CoreRules.FILTER_CORRELATE) + .build(); + HepPlanner prePlanner = new HepPlanner(program); + prePlanner.setRoot(rel); + RelNode preparedRel = prePlanner.findBestExp(); + + TopDownGeneralDecorrelator decorrelator = createEmptyDecorrelator(builder); + RelNode decorrelateNode = decorrelator.correlateElimination(preparedRel); + + HepProgram postProgram = HepProgram.builder() + .addRuleInstance(CoreRules.FILTER_PROJECT_TRANSPOSE) + .addRuleInstance(CoreRules.FILTER_INTO_JOIN) + .addRuleInstance(CoreRules.MARK_TO_SEMI_OR_ANTI_JOIN_RULE) + .addRuleInstance(CoreRules.PROJECT_MERGE) + .addRuleInstance(CoreRules.PROJECT_REMOVE) + .build(); + HepPlanner postPlanner = new HepPlanner(postProgram); + postPlanner.setRoot(decorrelateNode); + return postPlanner.findBestExp(); + } + + private RelNode correlateElimination(RelNode rel) { + if (rel instanceof Correlate) { + final Correlate correlate = (Correlate) rel; + final RelNode newLeft; + if (hasParent) { + newLeft = unnest(correlate.getLeft()); + } else { + newLeft = decorrelateQuery(correlate.getLeft(), builder); + } + + UnnestInfo leftInfo = mapRelToUnnestInfo.get(correlate.getLeft()); + TreeMap<CorDef, Integer> corDefOutputs = 
new TreeMap<>(); + Map<Integer, Integer> oldToNewOutputs = new HashMap<>(); + for (int i = 0; i < correlate.getLeft().getRowType().getFieldCount(); i++) { + int newColumnIndex = leftInfo == null ? i : requireNonNull(leftInfo.oldToNewOutputs.get(i)); + oldToNewOutputs.put(i, newColumnIndex); + if (correlate.getRequiredColumns().get(i)) { + CorDef corDef = new CorDef(correlate.getCorrelationId(), i); + corDefs.add(corDef); + corDefOutputs.put(corDef, newColumnIndex); + } + } + if (leftInfo != null) { + corDefOutputs.putAll(leftInfo.corDefOutputs); + } + leftInfo = new UnnestInfo(correlate.getLeft(), newLeft, corDefOutputs, oldToNewOutputs); + dedupFreeVarsNode = generateDedupFreeVarsNode(newLeft, leftInfo); + + detectCorrelatedExpressions(correlate.getRight()); + emptyOutputSensitive |= correlate.getJoinType() == JoinRelType.MARK; + RelNode newRight = unnest(correlate.getRight()); + UnnestInfo rightInfo = requireNonNull(mapRelToUnnestInfo.get(correlate.getRight())); + + builder.push(newLeft).push(newRight); + RexNode unnestedJoinCondition + = createUnnestedJoinCondition(correlate.getCondition(), leftInfo, rightInfo, true); + RelNode unnestedRel = builder.join(correlate.getJoinType(), unnestedJoinCondition).build(); + + if (!hasParent) { + builder.push(unnestedRel); + UnnestInfo unnestInfo = + createJoinUnnestInfo( + leftInfo, + rightInfo, + correlate, + unnestedRel, + correlate.getJoinType()); + List<RexNode> projects + = builder.fields(new ArrayList<>(unnestInfo.oldToNewOutputs.values())); + unnestedRel = builder.project(projects).build(); + } + return unnestedRel; + } else { + for (int i = 0; i < rel.getInputs().size(); i++) { + rel.replaceInput(i, correlateElimination(rel.getInput(i))); + } + } + return rel; + } + + private RelNode generateDedupFreeVarsNode(RelNode newLeft, UnnestInfo leftInfo) { + List<Integer> columnIndexes = new ArrayList<>(); + for (CorDef corDef : corDefs) { + if (leftInfo != null) { + int fieldIndex = 
requireNonNull(leftInfo.corDefOutputs.get(corDef)); + columnIndexes.add(fieldIndex); + } else { + columnIndexes.add(corDef.field); + } + } + List<RexNode> inputRefs = builder.push(newLeft) + .fields(columnIndexes); + return builder.project(inputRefs).distinct().build(); + } + + private boolean detectCorrelatedExpressions(RelNode rel) { + boolean hasCorrelation = false; + for (RelNode input : rel.getInputs()) { + hasCorrelation |= detectCorrelatedExpressions(input); + } + if (!hasCorrelation) { + RelOptUtil.VariableUsedVisitor variableUsedVisitor = + new RelOptUtil.VariableUsedVisitor(null); + rel.accept(variableUsedVisitor); + Set<CorrelationId> corrIdSet + = corDefs.stream() + .map(corDef -> corDef.corr) + .collect(ImmutableSet.toImmutableSet()); + hasCorrelation = + !variableUsedVisitor.variables.isEmpty() + && corrIdSet.containsAll(variableUsedVisitor.variables); + } + hasCorrelatedExpressions.put(rel, hasCorrelation); + return hasCorrelation; + } + + private RexNode createUnnestedJoinCondition( + RexNode oriCondition, + UnnestInfo leftInfo, + UnnestInfo rightInfo, + boolean needNatureJoinCond) { + Map<Integer, Integer> virtualOldToNewOutputs = new HashMap<>(); + int oriLeftFieldCount = leftInfo.oldRel.getRowType().getFieldCount(); + int newLeftFieldCount = leftInfo.r.getRowType().getFieldCount(); + virtualOldToNewOutputs.putAll(leftInfo.oldToNewOutputs); + rightInfo.oldToNewOutputs.forEach((oriIndex, newIndex) -> + virtualOldToNewOutputs.put( + requireNonNull(oriIndex, "oriIndex") + oriLeftFieldCount, + requireNonNull(newIndex, "newIndex") + newLeftFieldCount)); + + TreeMap<CorDef, Integer> virtualCorDefOutputs = new TreeMap<>(); + if (!leftInfo.corDefOutputs.isEmpty()) { + virtualCorDefOutputs.putAll(leftInfo.corDefOutputs); + } else if (!rightInfo.corDefOutputs.isEmpty()) { + rightInfo.corDefOutputs.forEach((corDef, index) -> + virtualCorDefOutputs.put(corDef, index + newLeftFieldCount)); + } else { + throw new IllegalArgumentException("The UnnestInfo for 
both sides of Join/Correlate that has " + + "correlation should not all be empty."); + } + + RelNode virtualOldRel = builder.push(leftInfo.oldRel).push(rightInfo.oldRel) + .join(JoinRelType.INNER) + .build(); + RelNode virtualNewRel = builder.push(leftInfo.r).push(rightInfo.r) + .join(JoinRelType.INNER) + .build(); + UnnestInfo virtualInfo = + new UnnestInfo(virtualOldRel, virtualNewRel, virtualCorDefOutputs, virtualOldToNewOutputs); + RexNode rewriteOriCondition = CorrelatedExprRewriter.rewrite(oriCondition, virtualInfo); + List<RexNode> unnestedJoinConditions = new ArrayList<>(); + unnestedJoinConditions.add(rewriteOriCondition); + + if (needNatureJoinCond) { + for (CorDef corDef : corDefs) { + int leftIndex = requireNonNull(leftInfo.corDefOutputs.get(corDef)); + RelDataType leftColumnType + = leftInfo.r.getRowType().getFieldList().get(leftIndex).getType(); + int rightIndex = requireNonNull(rightInfo.corDefOutputs.get(corDef)); + RelDataType rightColumnType + = rightInfo.r.getRowType().getFieldList().get(rightIndex).getType(); + RexNode notDistinctFrom = + builder.isNotDistinctFrom( + new RexInputRef(leftIndex, leftColumnType), + new RexInputRef(rightIndex + newLeftFieldCount, rightColumnType)); + unnestedJoinConditions.add(notDistinctFrom); + } + } + return RexUtil.composeConjunction(builder.getRexBuilder(), unnestedJoinConditions); + } + + private UnnestInfo createJoinUnnestInfo( + UnnestInfo leftInfo, + UnnestInfo rightInfo, + RelNode oriJoinNode, + RelNode unnestedJoinNode, + JoinRelType joinRelType) { + Map<Integer, Integer> oldToNewOutputs = new HashMap<>(); + oldToNewOutputs.putAll(leftInfo.oldToNewOutputs); + int oriLeftFieldCount = leftInfo.oldRel.getRowType().getFieldCount(); + int newLeftFieldCount = leftInfo.r.getRowType().getFieldCount(); + switch (joinRelType) { + case SEMI: + case ANTI: + break; + case MARK: + oldToNewOutputs.put(oriLeftFieldCount, newLeftFieldCount); + break; + default: + rightInfo.oldToNewOutputs.forEach((oriIndex, newIndex) -> + 
oldToNewOutputs.put( + requireNonNull(oriIndex, "oriIndex") + oriLeftFieldCount, + requireNonNull(newIndex, "newIndex") + newLeftFieldCount)); + break; + } + + TreeMap<CorDef, Integer> corDefOutputs = new TreeMap<>(); + if (!leftInfo.corDefOutputs.isEmpty()) { + corDefOutputs.putAll(leftInfo.corDefOutputs); + } else if (!rightInfo.corDefOutputs.isEmpty()) { + Litmus.THROW.check(joinRelType.projectsRight(), + "If the joinType doesn't project right, its left side must have UnnestInfo."); + rightInfo.corDefOutputs.forEach((corDef, index) -> + corDefOutputs.put(corDef, index + newLeftFieldCount)); + } else { + throw new IllegalArgumentException("The UnnestInfo for both sides of Join/Correlate that has " + + "correlation should not all be empty."); + } + return new UnnestInfo(oriJoinNode, unnestedJoinNode, corDefOutputs, oldToNewOutputs); + } + + private RelNode unnest(RelNode rel) { + if (!requireNonNull(hasCorrelatedExpressions.get(rel))) { + RelNode newRel + = builder.push(decorrelateQuery(rel, builder)) + .push(dedupFreeVarsNode) + .join(JoinRelType.INNER) + .build(); + Map<Integer, Integer> oldToNewOutputs = new HashMap<>(); + IntStream.range(0, rel.getRowType().getFieldCount()) + .forEach(i -> oldToNewOutputs.put(i, i)); + + int offset = rel.getRowType().getFieldCount(); + TreeMap<CorDef, Integer> corDefOutputs = new TreeMap<>(); + for (CorDef corDef : corDefs) { + corDefOutputs.put(corDef, offset++); + } + + UnnestInfo unnestInfo + = new UnnestInfo(rel, newRel, corDefOutputs, oldToNewOutputs); + mapRelToUnnestInfo.put(rel, unnestInfo); + return newRel; + } + return dispatcher.invoke(rel); + } + + public RelNode unnestInternal(Filter filter) { + RelNode newInput = unnest(filter.getInput()); + UnnestInfo inputInfo = requireNonNull(mapRelToUnnestInfo.get(filter.getInput())); + RexNode newCondition = + CorrelatedExprRewriter.rewrite(filter.getCondition(), inputInfo); + RelNode newFilter = builder.push(newInput).filter(newCondition).build(); + + UnnestInfo unnestInfo 
+ = new UnnestInfo(filter, newFilter, inputInfo.corDefOutputs, inputInfo.oldToNewOutputs); + mapRelToUnnestInfo.put(filter, unnestInfo); + return newFilter; + } + + public RelNode unnestInternal(Project project) { + for (RexNode expr : project.getProjects()) { + if (!Strong.isStrong(expr)) { + emptyOutputSensitive = true; + } + } + RelNode newInput = unnest(project.getInput()); + UnnestInfo inputInfo = requireNonNull(mapRelToUnnestInfo.get(project.getInput())); + List<RexNode> newProjects + = CorrelatedExprRewriter.rewrite(project.getProjects(), inputInfo); + + int oriFieldCount = newProjects.size(); + Map<Integer, Integer> oldToNewOutputs = new HashMap<>(); + IntStream.range(0, oriFieldCount).forEach(i -> oldToNewOutputs.put(i, i)); + + builder.push(newInput); + TreeMap<CorDef, Integer> corDefOutputs = new TreeMap<>(); + for (CorDef corDef : corDefs) { + newProjects.add(builder.field(requireNonNull(inputInfo.corDefOutputs.get(corDef)))); + corDefOutputs.put(corDef, oriFieldCount++); + } + RelNode newProject = builder.project(newProjects, ImmutableList.of(), true).build(); + UnnestInfo unnestInfo + = new UnnestInfo(project, newProject, corDefOutputs, oldToNewOutputs); + mapRelToUnnestInfo.put(project, unnestInfo); + return newProject; + } + + public RelNode unnestInternal(Aggregate aggregate) { + RelNode newInput = unnest(aggregate.getInput()); + UnnestInfo inputUnnestInfo = + requireNonNull(mapRelToUnnestInfo.get(aggregate.getInput())); + builder.push(newInput); + + ImmutableBitSet.Builder corKeyBuilder = ImmutableBitSet.builder(); + for (CorDef corDef : corDefs) { + int corKeyIndex = requireNonNull(inputUnnestInfo.corDefOutputs.get(corDef)); + corKeyBuilder.set(corKeyIndex); + } + ImmutableBitSet corKeyBitSet = corKeyBuilder.build(); + ImmutableBitSet newGroupSet + = aggregate.getGroupSet().permute(inputUnnestInfo.oldToNewOutputs) + .union(corKeyBitSet); + List<ImmutableBitSet> newGroupSets = new ArrayList<>(); + for (ImmutableBitSet bitSet : 
aggregate.getGroupSets()) { + ImmutableBitSet newBitSet + = bitSet.permute(inputUnnestInfo.oldToNewOutputs).union(corKeyBitSet); + newGroupSets.add(newBitSet); + } + + boolean hasCountFunction = false; + List<AggregateCall> permutedAggCalls = new ArrayList<>(); + Mappings.TargetMapping targetMapping = + Mappings.target( + inputUnnestInfo.oldToNewOutputs, + inputUnnestInfo.oldRel.getRowType().getFieldCount(), + inputUnnestInfo.r.getRowType().getFieldCount()); + for (AggregateCall aggCall : aggregate.getAggCallList()) { + hasCountFunction |= aggCall.getAggregation() instanceof SqlCountAggFunction; + permutedAggCalls.add(aggCall.transform(targetMapping)); + } + RelNode newAggregate + = builder.aggregate(builder.groupKey(newGroupSet, newGroupSets), permutedAggCalls).build(); + + Map<Integer, Integer> oldToNewOutputs = new HashMap<>(); + for (int groupKey : aggregate.getGroupSet()) { + int oriIndex = aggregate.getGroupSet().indexOf(groupKey); + int newIndex = newGroupSet.indexOf(groupKey); + oldToNewOutputs.put(oriIndex, newIndex); + } + for (int i = 0; i < aggregate.getAggCallList().size(); i++) { + oldToNewOutputs.put( + aggregate.getGroupCount() + i, + newGroupSet.cardinality() + i); + } + TreeMap<CorDef, Integer> corDefOutputs = new TreeMap<>(); + for (CorDef corDef : corDefs) { + int index = requireNonNull(inputUnnestInfo.corDefOutputs.get(corDef)); + corDefOutputs.put(corDef, newGroupSet.indexOf(index)); + } + + if (aggregate.hasEmptyGroup() + && (emptyOutputSensitive || hasCountFunction)) { + builder.push(dedupFreeVarsNode).push(newAggregate); + List<RexNode> leftJoinConditions = new ArrayList<>(); + int freeVarsIndex = 0; + for (CorDef corDef : corDefs) { + RexNode notDistinctFrom = + builder.isNotDistinctFrom( + builder.field(2, 0, freeVarsIndex), + builder.field(2, 1, requireNonNull(corDefOutputs.get(corDef)))); + leftJoinConditions.add(notDistinctFrom); + + corDefOutputs.put(corDef, freeVarsIndex++); + } + builder.join(JoinRelType.LEFT, leftJoinConditions); + 
+ // rewrite COUNT to case when + List<RexNode> aggCallProjects = new ArrayList<>(); + int aggCallStartIndex = + dedupFreeVarsNode.getRowType().getFieldCount() + newGroupSet.cardinality(); + for (int i = 0; i < permutedAggCalls.size(); i++) { + int index = aggCallStartIndex + i; + SqlAggFunction aggregation = permutedAggCalls.get(i).getAggregation(); + if (aggregation instanceof SqlCountAggFunction) { + RexNode caseWhenRewrite = + builder.call( + SqlStdOperatorTable.CASE, + builder.isNotNull(builder.field(index)), + builder.field(index), + builder.literal(0)); + aggCallProjects.add(caseWhenRewrite); + } else { + aggCallProjects.add(builder.field(index)); + } + } + List<RexNode> projects = + new ArrayList<>(builder.fields(ImmutableBitSet.range(0, aggCallStartIndex))); + projects.addAll(aggCallProjects); + newAggregate = builder.project(projects).build(); + + + for (Map.Entry<Integer, Integer> entry : oldToNewOutputs.entrySet()) { + int value = requireNonNull(entry.getValue()); + entry.setValue(value + corDefs.size()); + } + } + UnnestInfo unnestInfo + = new UnnestInfo(aggregate, newAggregate, corDefOutputs, oldToNewOutputs); + mapRelToUnnestInfo.put(aggregate, unnestInfo); + return newAggregate; + } + + public RelNode unnestInternal(Sort sort) { + RelNode newInput = unnest(sort.getInput()); + UnnestInfo inputInfo = + requireNonNull(mapRelToUnnestInfo.get(sort.getInput())); + Mappings.TargetMapping targetMapping = + Mappings.target( + inputInfo.oldToNewOutputs, + inputInfo.oldRel.getRowType().getFieldCount(), + inputInfo.r.getRowType().getFieldCount()); + RelCollation shiftCollation = sort.getCollation().apply(targetMapping); + builder.push(newInput); + + if (!sort.collation.getFieldCollations().isEmpty() + && (sort.offset != null || sort.fetch != null)) { + List<RexNode> partitionKeys = new ArrayList<>(); + for (CorDef corDef : corDefs) { + int partitionKeyIndex = requireNonNull(inputInfo.corDefOutputs.get(corDef)); + 
partitionKeys.add(builder.field(partitionKeyIndex)); + } + RexNode rowNumber = builder.aggregateCall(SqlStdOperatorTable.ROW_NUMBER) + .over() + .partitionBy(partitionKeys) + .orderBy(builder.fields(shiftCollation)) + .toRex(); + List<RexNode> projectsWithRowNumber = new ArrayList<>(builder.fields()); + projectsWithRowNumber.add(rowNumber); + builder.project(projectsWithRowNumber); + + List<RexNode> conditions = new ArrayList<>(); + if (sort.offset != null) { + RexNode greaterThenLowerBound = + builder.call( + SqlStdOperatorTable.GREATER_THAN, + builder.field(projectsWithRowNumber.size() - 1), + sort.offset); + conditions.add(greaterThenLowerBound); + } + if (sort.fetch != null) { + RexNode upperBound = sort.offset == null + ? sort.fetch + : builder.call(SqlStdOperatorTable.PLUS, sort.offset, sort.fetch); + RexNode lessThenOrEqualUpperBound = + builder.call( + SqlStdOperatorTable.LESS_THAN_OR_EQUAL, + builder.field(projectsWithRowNumber.size() - 1), + upperBound); + conditions.add(lessThenOrEqualUpperBound); + } + builder.filter(conditions); + } else { + builder.sortLimit(sort.offset, sort.fetch, builder.fields(shiftCollation)); + } + RelNode newSort = builder.build(); + UnnestInfo unnestInfo + = new UnnestInfo(sort, newSort, inputInfo.corDefOutputs, inputInfo.oldToNewOutputs); + mapRelToUnnestInfo.put(sort, unnestInfo); + return newSort; + } + + public RelNode unnestInternal(Correlate correlate) { + TopDownGeneralDecorrelator subDecorrelator = createSubDecorrelator(); + Join newJoin = (Join) subDecorrelator.correlateElimination(correlate); + + UnnestInfo leftInfo + = requireNonNull(subDecorrelator.mapRelToUnnestInfo.get(correlate.getLeft())); + UnnestInfo rightInfo + = requireNonNull(subDecorrelator.mapRelToUnnestInfo.get(correlate.getRight())); + UnnestInfo unnestInfo + = createJoinUnnestInfo(leftInfo, rightInfo, correlate, newJoin, correlate.getJoinType()); + mapRelToUnnestInfo.put(correlate, unnestInfo); + return newJoin; + } + + public RelNode 
unnestInternal(Join join) { + boolean leftHasCorrelation = + requireNonNull(hasCorrelatedExpressions.get(join.getLeft())); + boolean rightHasCorrelation = + requireNonNull(hasCorrelatedExpressions.get(join.getRight())); + boolean pushDownToLeft = false; + boolean pushDownToRight = false; + RelNode newLeft; + RelNode newRight; + UnnestInfo leftInfo; + UnnestInfo rightInfo; + + if (!leftHasCorrelation && !join.getJoinType().generatesNullsOnRight() + && join.getJoinType().projectsRight()) { + newLeft = decorrelateQuery(join.getLeft(), builder); + Map<Integer, Integer> leftOldToNewOutputs = new HashMap<>(); + IntStream.range(0, newLeft.getRowType().getFieldCount()) + .forEach(i -> leftOldToNewOutputs.put(i, i)); + leftInfo = new UnnestInfo(join.getLeft(), newLeft, new TreeMap<>(), leftOldToNewOutputs); + } else { + newLeft = unnest(join.getLeft()); + pushDownToLeft = true; + leftInfo = requireNonNull(mapRelToUnnestInfo.get(join.getLeft())); + } + if (!rightHasCorrelation && !join.getJoinType().generatesNullsOnLeft()) { + newRight = decorrelateQuery(join.getRight(), builder); + Map<Integer, Integer> rightOldToNewOutputs = new HashMap<>(); + IntStream.range(0, newRight.getRowType().getFieldCount()) + .forEach(i -> rightOldToNewOutputs.put(i, i)); + rightInfo = new UnnestInfo(join.getRight(), newRight, new TreeMap<>(), rightOldToNewOutputs); + } else { + emptyOutputSensitive |= join.getJoinType() == JoinRelType.MARK; + newRight = unnest(join.getRight()); + pushDownToRight = true; + rightInfo = requireNonNull(mapRelToUnnestInfo.get(join.getRight())); + } + + builder.push(newLeft).push(newRight); + RexNode newJoinCondition = + createUnnestedJoinCondition( + join.getCondition(), + leftInfo, + rightInfo, + pushDownToLeft && pushDownToRight); + RelNode newJoin = builder.join(join.getJoinType(), newJoinCondition).build(); + UnnestInfo unnestInfo = + createJoinUnnestInfo( + leftInfo, + rightInfo, + join, + newJoin, + join.getJoinType()); + mapRelToUnnestInfo.put(join, 
unnestInfo); + return newJoin; + } + + public RelNode unnestInternal(SetOp setOp) { + List<RelNode> newInputs = new ArrayList<>(); + for (RelNode input : setOp.getInputs()) { + RelNode newInput = unnest(input); + builder.push(newInput); + + UnnestInfo inputInfo = requireNonNull(mapRelToUnnestInfo.get(input)); + List<Integer> projectIndexes = new ArrayList<>(); + for (int i = 0; i < inputInfo.oldRel.getRowType().getFieldCount(); i++) { + projectIndexes.add(requireNonNull(inputInfo.oldToNewOutputs.get(i))); + } + for (CorDef corDef : corDefs) { + projectIndexes.add(requireNonNull(inputInfo.corDefOutputs.get(corDef))); + } + builder.project(builder.fields(projectIndexes)); + newInputs.add(builder.build()); + } + builder.pushAll(newInputs); + switch (setOp.kind) { + case UNION: + builder.union(setOp.all, newInputs.size()); + break; + case INTERSECT: + builder.intersect(setOp.all, newInputs.size()); + break; + case EXCEPT: + builder.minus(setOp.all, newInputs.size()); + break; + } + RelNode newSetOp = builder.build(); + + int oriSetOpFieldCount = setOp.getRowType().getFieldCount(); + Map<Integer, Integer> oldToNewOutputs = new HashMap<>(); + IntStream.range(0, oriSetOpFieldCount).forEach(i -> oldToNewOutputs.put(i, i)); + TreeMap<CorDef, Integer> corDefOutputs = new TreeMap<>(); + for (CorDef corDef : corDefs) { + corDefOutputs.put(corDef, oriSetOpFieldCount++); + } + UnnestInfo unnestInfo = new UnnestInfo(setOp, newSetOp, corDefOutputs, oldToNewOutputs); + mapRelToUnnestInfo.put(setOp, unnestInfo); + return newSetOp; + } + + public RelNode unnestInternal(RelNode other) { + throw new UnsupportedOperationException("Top-down general decorrelator does not support: " + + other.getClass().getSimpleName()); + } + + /** + * Rewrites correlated expressions, window function and shift input references. 
+ */ + static class CorrelatedExprRewriter extends RexShuttle { + final UnnestInfo unnestInfo; + + CorrelatedExprRewriter(UnnestInfo unnestInfo) { + this.unnestInfo = unnestInfo; + } + + static RexNode rewrite( + RexNode expr, + UnnestInfo unnestInfo) { + CorrelatedExprRewriter rewriter = new CorrelatedExprRewriter(unnestInfo); + return expr.accept(rewriter); + } + + static List<RexNode> rewrite( + List<RexNode> exprs, + UnnestInfo unnestInfo) { + CorrelatedExprRewriter rewriter = new CorrelatedExprRewriter(unnestInfo); + return new ArrayList<>(rewriter.apply(exprs)); + } + + @Override public RexNode visitInputRef(RexInputRef inputRef) { + int newIndex = requireNonNull(unnestInfo.oldToNewOutputs.get(inputRef.getIndex())); + if (newIndex == inputRef.getIndex()) { + return inputRef; + } + return new RexInputRef(newIndex, inputRef.getType()); + } + + @Override public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { + if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable) { + RexCorrelVariable v = + (RexCorrelVariable) fieldAccess.getReferenceExpr(); + CorDef corDef = new CorDef(v.id, fieldAccess.getField().getIndex()); + int newIndex = requireNonNull(unnestInfo.corDefOutputs.get(corDef)); + return new RexInputRef(newIndex, fieldAccess.getType()); + } + return super.visitFieldAccess(fieldAccess); + } + + @Override public RexWindow visitWindow(RexWindow window) { + RexWindow shiftedWindow = super.visitWindow(window); + List<RexNode> newPartitionKeys = new ArrayList<>(shiftedWindow.partitionKeys); + for (Integer corIndex : unnestInfo.corDefOutputs.values()) { + RexInputRef inputRef = + new RexInputRef( + corIndex, + unnestInfo.r.getRowType().getFieldList().get(corIndex).getType()); + newPartitionKeys.add(inputRef); + } + return unnestInfo.r.getCluster().getRexBuilder().makeWindow( + newPartitionKeys, + window.orderKeys, + window.getLowerBound(), + window.getUpperBound(), + window.isRows(), + window.getExclude()); + } + } + + /** + * Unnesting information. 
Review Comment: I believe some optimizations, such as predicate push-down, work for joins but not for aggregates. This can be improved later if necessary. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
