https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/174094
>From fba22ca1cda0e1fa231c0972ddff6d4829a9eb3a Mon Sep 17 00:00:00 2001 From: Matthias Springer <[email protected]> Date: Wed, 31 Dec 2025 14:07:51 +0000 Subject: [PATCH] [mlir][draft] Consolidate patterns into RegionBranchOpInterface patterns fix some tests reorganize code --- .../mlir/Interfaces/ControlFlowInterfaces.h | 9 + .../mlir/Interfaces/ControlFlowInterfaces.td | 5 + mlir/lib/Dialect/SCF/IR/SCF.cpp | 830 +----------------- mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 483 ++++++++++ mlir/test/Dialect/SCF/canonicalize.mlir | 24 +- mlir/test/Transforms/remove-dead-values.mlir | 8 +- 6 files changed, 529 insertions(+), 830 deletions(-) diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h index 566f4b8fadb5d..ea85b2d1b5cb6 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h @@ -16,6 +16,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/raw_ostream.h" @@ -188,6 +189,8 @@ LogicalResult verifyTypesAlongControlFlowEdges(Operation *op); /// possible successors.) Operands that not forwarded at all are not present in /// the mapping. using RegionBranchSuccessorMapping = DenseMap<OpOperand *, SmallVector<Value>>; +using RegionBranchInverseSuccessorMapping = + DenseMap<Value, SmallVector<OpOperand *>>; /// This class represents a successor of a region. A region successor can either /// be another region, or the parent operation. If the successor is a region, @@ -350,6 +353,12 @@ Region *getEnclosingRepetitiveRegion(Operation *op); /// exists. Region *getEnclosingRepetitiveRegion(Value value); +/// Populate canonicalization patterns that simplify successor operands/inputs +/// of region branch operations. Only operations with the given name are +/// matched. 
+void populateRegionBranchOpInterfaceCanonicalizationPatterns( + RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit = 1); + //===----------------------------------------------------------------------===// // ControlFlow Traits //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td index 2e654ba04ffe5..70aed9e1e11c6 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -355,6 +355,11 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> { ::mlir::RegionBranchSuccessorMapping &mapping, std::optional<::mlir::RegionBranchPoint> src = std::nullopt); + /// Build a mapping from successor inputs to successor operands. This is + /// the same as "getSuccessorOperandInputMapping", but inverted. + void getSuccessorInputOperandMapping( + ::mlir::RegionBranchInverseSuccessorMapping &mapping); + /// Return all possible region branch points: the region branch op itself /// and all region branch terminators. ::llvm::SmallVector<::mlir::RegionBranchPoint> getAllRegionBranchPoints(); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 8803a6d136f7a..95a854b655a53 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -291,102 +291,11 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { } }; -// Pattern to eliminate ExecuteRegionOp results which forward external -// values from the region. In case there are multiple yield operations, -// all of them must have the same operands in order for the pattern to be -// applicable. 
-struct ExecuteRegionForwardingEliminator - : public OpRewritePattern<ExecuteRegionOp> { - using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(ExecuteRegionOp op, - PatternRewriter &rewriter) const override { - if (op.getNumResults() == 0) - return failure(); - - SmallVector<Operation *> yieldOps; - for (Block &block : op.getRegion()) { - if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator())) - yieldOps.push_back(yield.getOperation()); - } - - if (yieldOps.empty()) - return failure(); - - // Check if all yield operations have the same operands. - auto yieldOpsOperands = yieldOps[0]->getOperands(); - for (auto *yieldOp : yieldOps) { - if (yieldOp->getOperands() != yieldOpsOperands) - return failure(); - } - - SmallVector<Value> externalValues; - SmallVector<Value> internalValues; - SmallVector<Value> opResultsToReplaceWithExternalValues; - SmallVector<Value> opResultsToKeep; - for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) { - if (isValueFromInsideRegion(yieldedValue, op)) { - internalValues.push_back(yieldedValue); - opResultsToKeep.push_back(op.getResult(index)); - } else { - externalValues.push_back(yieldedValue); - opResultsToReplaceWithExternalValues.push_back(op.getResult(index)); - } - } - // No yielded external values - nothing to do. - if (externalValues.empty()) - return failure(); - - // There are yielded external values - create a new execute_region returning - // just the internal values. - SmallVector<Type> resultTypes; - for (Value value : internalValues) - resultTypes.push_back(value.getType()); - auto newOp = - ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes)); - newOp->setAttrs(op->getAttrs()); - - // Move old op's region to the new operation. - rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), - newOp.getRegion().end()); - - // Replace all yield operations with a new yield operation with updated - // results. 
scf.execute_region must have at least one yield operation. - for (auto *yieldOp : yieldOps) { - rewriter.setInsertionPoint(yieldOp); - rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, - ValueRange(internalValues)); - } - - // Replace the old operation with the external values directly. - rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues, - externalValues); - // Replace the old operation's remaining results with the new operation's - // results. - rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults()); - rewriter.eraseOp(op); - return success(); - } - -private: - bool isValueFromInsideRegion(Value value, - ExecuteRegionOp executeRegionOp) const { - // Check if the value is defined within the execute_region - if (Operation *defOp = value.getDefiningOp()) - return &executeRegionOp.getRegion() == defOp->getParentRegion(); - - // If it's a block argument, check if it's from within the region - if (BlockArgument blockArg = dyn_cast<BlockArgument>(value)) - return &executeRegionOp.getRegion() == blockArg.getParentRegion(); - - return false; // Value is from outside the region - } -}; - void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner, - ExecuteRegionForwardingEliminator>(context); + results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context); + populateRegionBranchOpInterfaceCanonicalizationPatterns( + results, ExecuteRegionOp::getOperationName()); } void ExecuteRegionOp::getSuccessorRegions( @@ -989,146 +898,6 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, } namespace { -// Fold away ForOp iter arguments when: -// 1) The op yields the iter arguments. -// 2) The argument's corresponding outer region iterators (inputs) are yielded. -// 3) The iter arguments have no use and the corresponding (operation) results -// have no use. 
-// -// These arguments must be defined outside of the ForOp region and can just be -// forwarded after simplifying the op inits, yields and returns. -// -// The implementation uses `inlineBlockBefore` to steal the content of the -// original ForOp and avoid cloning. -struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> { - using OpRewritePattern<scf::ForOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(scf::ForOp forOp, - PatternRewriter &rewriter) const final { - bool canonicalize = false; - - // An internal flat vector of block transfer - // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to - // transformed block argument mappings. This plays the role of a - // IRMapping for the particular use case of calling into - // `inlineBlockBefore`. - int64_t numResults = forOp.getNumResults(); - SmallVector<bool, 4> keepMask; - keepMask.reserve(numResults); - SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues, - newResultValues; - newBlockTransferArgs.reserve(1 + numResults); - newBlockTransferArgs.push_back(Value()); // iv placeholder with null value - newIterArgs.reserve(forOp.getInitArgs().size()); - newYieldValues.reserve(numResults); - newResultValues.reserve(numResults); - DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg; - for (auto [init, arg, result, yielded] : - llvm::zip(forOp.getInitArgs(), // iter from outside - forOp.getRegionIterArgs(), // iter inside region - forOp.getResults(), // op results - forOp.getYieldedValues() // iter yield - )) { - // Forwarded is `true` when: - // 1) The region `iter` argument is yielded. - // 2) The region `iter` argument the corresponding input is yielded. - // 3) The region `iter` argument has no use, and the corresponding op - // result has no use. 
- bool forwarded = (arg == yielded) || (init == yielded) || - (arg.use_empty() && result.use_empty()); - if (forwarded) { - canonicalize = true; - keepMask.push_back(false); - newBlockTransferArgs.push_back(init); - newResultValues.push_back(init); - continue; - } - - // Check if a previous kept argument always has the same values for init - // and yielded values. - if (auto it = initYieldToArg.find({init, yielded}); - it != initYieldToArg.end()) { - canonicalize = true; - keepMask.push_back(false); - auto [sameArg, sameResult] = it->second; - rewriter.replaceAllUsesWith(arg, sameArg); - rewriter.replaceAllUsesWith(result, sameResult); - // The replacement value doesn't matter because there are no uses. - newBlockTransferArgs.push_back(init); - newResultValues.push_back(init); - continue; - } - - // This value is kept. - initYieldToArg.insert({{init, yielded}, {arg, result}}); - keepMask.push_back(true); - newIterArgs.push_back(init); - newYieldValues.push_back(yielded); - newBlockTransferArgs.push_back(Value()); // placeholder with null value - newResultValues.push_back(Value()); // placeholder with null value - } - - if (!canonicalize) - return failure(); - - scf::ForOp newForOp = - scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), - forOp.getUpperBound(), forOp.getStep(), newIterArgs, - /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp()); - newForOp->setAttrs(forOp->getAttrs()); - Block &newBlock = newForOp.getRegion().front(); - - // Replace the null placeholders with newly constructed values. 
- newBlockTransferArgs[0] = newBlock.getArgument(0); // iv - for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size(); - idx != e; ++idx) { - Value &blockTransferArg = newBlockTransferArgs[1 + idx]; - Value &newResultVal = newResultValues[idx]; - assert((blockTransferArg && newResultVal) || - (!blockTransferArg && !newResultVal)); - if (!blockTransferArg) { - blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx]; - newResultVal = newForOp.getResult(collapsedIdx++); - } - } - - Block &oldBlock = forOp.getRegion().front(); - assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() && - "unexpected argument size mismatch"); - - // No results case: the scf::ForOp builder already created a zero - // result terminator. Merge before this terminator and just get rid of the - // original terminator that has been merged in. - if (newIterArgs.empty()) { - auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator()); - rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs); - rewriter.eraseOp(newBlock.getTerminator()->getPrevNode()); - rewriter.replaceOp(forOp, newResultValues); - return success(); - } - - // No terminator case: merge and rewrite the merged terminator. 
- auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(mergedTerminator); - SmallVector<Value, 4> filteredOperands; - filteredOperands.reserve(newResultValues.size()); - for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx) - if (keepMask[idx]) - filteredOperands.push_back(mergedTerminator.getOperand(idx)); - scf::YieldOp::create(rewriter, mergedTerminator.getLoc(), - filteredOperands); - }; - - rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs); - auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator()); - cloneFilteredTerminator(mergedYieldOp); - rewriter.eraseOp(mergedYieldOp); - rewriter.replaceOp(forOp, newResultValues); - return success(); - } -}; - /// Rewriting pattern that erases loops that are known not to iterate, replaces /// single-iteration loops with their bodies, and removes empty loops that /// iterate at least once and only return values defined outside of the loop. @@ -1235,13 +1004,13 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> { return failure(); } }; - } // namespace void ForOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>( - context); + results.add<SimplifyTrivialLoops, ForOpTensorCastFolder>(context); + populateRegionBranchOpInterfaceCanonicalizationPatterns( + results, ForOp::getOperationName()); } std::optional<APInt> ForOp::getConstantStep() { @@ -2378,35 +2147,6 @@ void IfOp::getRegionInvocationBounds( } namespace { -// Pattern to remove unused IfOp results. -struct RemoveUnusedResults : public OpRewritePattern<IfOp> { - using OpRewritePattern<IfOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(IfOp op, - PatternRewriter &rewriter) const override { - // Compute the list of unused results. 
- BitVector toErase(op.getNumResults(), false); - for (auto [idx, result] : llvm::enumerate(op.getResults())) - if (result.use_empty()) - toErase[idx] = true; - if (toErase.none()) - return rewriter.notifyMatchFailure(op, "no results to erase"); - - // Erase results. - auto newOp = cast<scf::IfOp>(rewriter.eraseOpResults(op, toErase)); - - // Erase operands. - rewriter.modifyOpInPlace(newOp.thenYield(), [&]() { - newOp.thenYield()->eraseOperands(toErase); - }); - rewriter.modifyOpInPlace(newOp.elseYield(), [&]() { - newOp.elseYield()->eraseOperands(toErase); - }); - - return success(); - } -}; - struct RemoveStaticCondition : public OpRewritePattern<IfOp> { using OpRewritePattern<IfOp>::OpRewritePattern; @@ -2977,8 +2717,10 @@ void IfOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add<CombineIfs, CombineNestedIfs, ConditionPropagation, ConvertTrivialIfToSelect, RemoveEmptyElseBranch, - RemoveStaticCondition, RemoveUnusedResults, - ReplaceIfYieldWithConditionOrValue>(context); + RemoveStaticCondition, ReplaceIfYieldWithConditionOrValue>( + context); + populateRegionBranchOpInterfaceCanonicalizationPatterns( + results, IfOp::getOperationName()); } Block *IfOp::thenBlock() { return &getThenRegion().back(); } @@ -3816,390 +3558,6 @@ struct WhileConditionTruth : public OpRewritePattern<WhileOp> { } }; -/// Remove loop invariant arguments from `before` block of scf.while. -/// A before block argument is considered loop invariant if :- -/// 1. i-th yield operand is equal to the i-th while operand. -/// 2. i-th yield operand is k-th after block argument which is (k+1)-th -/// condition operand AND this (k+1)-th condition operand is equal to i-th -/// iter argument/while operand. -/// For the arguments which are removed, their uses inside scf.while -/// are replaced with their corresponding initial value. 
-/// -/// Eg: -/// INPUT :- -/// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b, -/// ..., %argN_before = %N) -/// { -/// ... -/// scf.condition(%cond) %arg1_before, %arg0_before, -/// %arg2_before, %arg0_before, ... -/// } do { -/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2, -/// ..., %argK_after): -/// ... -/// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN -/// } -/// -/// OUTPUT :- -/// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before = -/// %N) -/// { -/// ... -/// scf.condition(%cond) %b, %a, %arg2_before, %a, ... -/// } do { -/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2, -/// ..., %argK_after): -/// ... -/// scf.yield %arg1_after, ..., %argN -/// } -/// -/// EXPLANATION: -/// We iterate over each yield operand. -/// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand -/// %arg0_before, which in turn is the 0-th iter argument. So we -/// remove 0-th before block argument and yield operand, and replace -/// all uses of the 0-th before block argument with its initial value -/// %a. -/// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial -/// value. So we remove this operand and the corresponding before -/// block argument and replace all uses of 1-th before block argument -/// with %b. 
-struct RemoveLoopInvariantArgsFromBeforeBlock - : public OpRewritePattern<WhileOp> { - using OpRewritePattern<WhileOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(WhileOp op, - PatternRewriter &rewriter) const override { - Block &afterBlock = *op.getAfterBody(); - Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments(); - ConditionOp condOp = op.getConditionOp(); - OperandRange condOpArgs = condOp.getArgs(); - Operation *yieldOp = afterBlock.getTerminator(); - ValueRange yieldOpArgs = yieldOp->getOperands(); - - bool canSimplify = false; - for (const auto &it : - llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) { - auto index = static_cast<unsigned>(it.index()); - auto [initVal, yieldOpArg] = it.value(); - // If i-th yield operand is equal to the i-th operand of the scf.while, - // the i-th before block argument is a loop invariant. - if (yieldOpArg == initVal) { - canSimplify = true; - break; - } - // If the i-th yield operand is k-th after block argument, then we check - // if the (k+1)-th condition op operand is equal to either the i-th before - // block argument or the initial value of i-th before block argument. If - // the comparison results `true`, i-th before block argument is a loop - // invariant. 
- auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg); - if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) { - Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()]; - if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) { - canSimplify = true; - break; - } - } - } - - if (!canSimplify) - return failure(); - - SmallVector<Value> newInitArgs, newYieldOpArgs; - DenseMap<unsigned, Value> beforeBlockInitValMap; - SmallVector<Location> newBeforeBlockArgLocs; - for (const auto &it : - llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) { - auto index = static_cast<unsigned>(it.index()); - auto [initVal, yieldOpArg] = it.value(); - - // If i-th yield operand is equal to the i-th operand of the scf.while, - // the i-th before block argument is a loop invariant. - if (yieldOpArg == initVal) { - beforeBlockInitValMap.insert({index, initVal}); - continue; - } else { - // If the i-th yield operand is k-th after block argument, then we check - // if the (k+1)-th condition op operand is equal to either the i-th - // before block argument or the initial value of i-th before block - // argument. If the comparison results `true`, i-th before block - // argument is a loop invariant. 
- auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg); - if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) { - Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()]; - if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) { - beforeBlockInitValMap.insert({index, initVal}); - continue; - } - } - } - newInitArgs.emplace_back(initVal); - newYieldOpArgs.emplace_back(yieldOpArg); - newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc()); - } - - { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(yieldOp); - rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs); - } - - auto newWhile = WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(), - newInitArgs); - - Block &newBeforeBlock = *rewriter.createBlock( - &newWhile.getBefore(), /*insertPt*/ {}, - ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs); - - Block &beforeBlock = *op.getBeforeBody(); - SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments()); - // For each i-th before block argument we find it's replacement value as :- - // 1. If i-th before block argument is a loop invariant, we fetch it's - // initial value from `beforeBlockInitValMap` by querying for key `i`. - // 2. Else we fetch j-th new before block argument as the replacement - // value of i-th before block argument. - for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) { - // If the index 'i' argument was a loop invariant we fetch it's initial - // value from `beforeBlockInitValMap`. 
- if (beforeBlockInitValMap.count(i) != 0) - newBeforeBlockArgs[i] = beforeBlockInitValMap[i]; - else - newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++); - } - - rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs); - rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(), - newWhile.getAfter().begin()); - - rewriter.replaceOp(op, newWhile.getResults()); - return success(); - } -}; - -/// Remove loop invariant value from result (condition op) of scf.while. -/// A value is considered loop invariant if the final value yielded by -/// scf.condition is defined outside of the `before` block. We remove the -/// corresponding argument in `after` block and replace the use with the value. -/// We also replace the use of the corresponding result of scf.while with the -/// value. -/// -/// Eg: -/// INPUT :- -/// %res_input:K = scf.while <...> iter_args(%arg0_before = , ..., -/// %argN_before = %N) { -/// ... -/// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ... -/// } do { -/// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after): -/// ... -/// some_func(%arg1_after) -/// ... -/// scf.yield %arg0_after, %arg2_after, ..., %argN_after -/// } -/// -/// OUTPUT :- -/// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) { -/// ... -/// scf.condition(%cond) %arg0, %arg1, ..., %argM -/// } do { -/// ^bb0(%arg0, %arg3, ..., %argM): -/// ... -/// some_func(%a) -/// ... -/// scf.yield %arg0, %b, ..., %argN -/// } -/// -/// EXPLANATION: -/// 1. The 1-th and 2-th operand of scf.condition are defined outside the -/// before block of scf.while, so they get removed. -/// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are -/// replaced by %b. -/// 3. The corresponding after block argument %arg1_after's uses are -/// replaced by %a and %arg2_after's uses are replaced by %b. 
-struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> { - using OpRewritePattern<WhileOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(WhileOp op, - PatternRewriter &rewriter) const override { - Block &beforeBlock = *op.getBeforeBody(); - ConditionOp condOp = op.getConditionOp(); - OperandRange condOpArgs = condOp.getArgs(); - - bool canSimplify = false; - for (Value condOpArg : condOpArgs) { - // Those values not defined within `before` block will be considered as - // loop invariant values. We map the corresponding `index` with their - // value. - if (condOpArg.getParentBlock() != &beforeBlock) { - canSimplify = true; - break; - } - } - - if (!canSimplify) - return failure(); - - Block::BlockArgListType afterBlockArgs = op.getAfterArguments(); - - SmallVector<Value> newCondOpArgs; - SmallVector<Type> newAfterBlockType; - DenseMap<unsigned, Value> condOpInitValMap; - SmallVector<Location> newAfterBlockArgLocs; - for (const auto &it : llvm::enumerate(condOpArgs)) { - auto index = static_cast<unsigned>(it.index()); - Value condOpArg = it.value(); - // Those values not defined within `before` block will be considered as - // loop invariant values. We map the corresponding `index` with their - // value. 
- if (condOpArg.getParentBlock() != &beforeBlock) { - condOpInitValMap.insert({index, condOpArg}); - } else { - newCondOpArgs.emplace_back(condOpArg); - newAfterBlockType.emplace_back(condOpArg.getType()); - newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc()); - } - } - - { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(condOp); - rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(), - newCondOpArgs); - } - - auto newWhile = WhileOp::create(rewriter, op.getLoc(), newAfterBlockType, - op.getOperands()); - - Block &newAfterBlock = - *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {}, - newAfterBlockType, newAfterBlockArgLocs); - - Block &afterBlock = *op.getAfterBody(); - // Since a new scf.condition op was created, we need to fetch the new - // `after` block arguments which will be used while replacing operations of - // previous scf.while's `after` blocks. We'd also be fetching new result - // values too. - SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments()); - SmallVector<Value> newWhileResults(afterBlock.getNumArguments()); - for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) { - Value afterBlockArg, result; - // If index 'i' argument was loop invariant we fetch it's value from the - // `condOpInitMap` map. - if (condOpInitValMap.count(i) != 0) { - afterBlockArg = condOpInitValMap[i]; - result = afterBlockArg; - } else { - afterBlockArg = newAfterBlock.getArgument(j); - result = newWhile.getResult(j); - j++; - } - newAfterBlockArgs[i] = afterBlockArg; - newWhileResults[i] = result; - } - - rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs); - rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(), - newWhile.getBefore().begin()); - - rewriter.replaceOp(op, newWhileResults); - return success(); - } -}; - -/// Remove WhileOp results that are also unused in 'after' block. 
-/// -/// %0:2 = scf.while () : () -> (i32, i64) { -/// %condition = "test.condition"() : () -> i1 -/// %v1 = "test.get_some_value"() : () -> i32 -/// %v2 = "test.get_some_value"() : () -> i64 -/// scf.condition(%condition) %v1, %v2 : i32, i64 -/// } do { -/// ^bb0(%arg0: i32, %arg1: i64): -/// "test.use"(%arg0) : (i32) -> () -/// scf.yield -/// } -/// return %0#0 : i32 -/// -/// becomes -/// %0 = scf.while () : () -> (i32) { -/// %condition = "test.condition"() : () -> i1 -/// %v1 = "test.get_some_value"() : () -> i32 -/// %v2 = "test.get_some_value"() : () -> i64 -/// scf.condition(%condition) %v1 : i32 -/// } do { -/// ^bb0(%arg0: i32): -/// "test.use"(%arg0) : (i32) -> () -/// scf.yield -/// } -/// return %0 : i32 -struct WhileUnusedResult : public OpRewritePattern<WhileOp> { - using OpRewritePattern<WhileOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(WhileOp op, - PatternRewriter &rewriter) const override { - auto term = op.getConditionOp(); - auto afterArgs = op.getAfterArguments(); - auto termArgs = term.getArgs(); - - // Collect results mapping, new terminator args and new result types. 
- SmallVector<unsigned> newResultsIndices; - SmallVector<Type> newResultTypes; - SmallVector<Value> newTermArgs; - SmallVector<Location> newArgLocs; - bool needUpdate = false; - for (const auto &it : - llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) { - auto i = static_cast<unsigned>(it.index()); - Value result = std::get<0>(it.value()); - Value afterArg = std::get<1>(it.value()); - Value termArg = std::get<2>(it.value()); - if (result.use_empty() && afterArg.use_empty()) { - needUpdate = true; - } else { - newResultsIndices.emplace_back(i); - newTermArgs.emplace_back(termArg); - newResultTypes.emplace_back(result.getType()); - newArgLocs.emplace_back(result.getLoc()); - } - } - - if (!needUpdate) - return failure(); - - { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(term); - rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(), - newTermArgs); - } - - auto newWhile = - WhileOp::create(rewriter, op.getLoc(), newResultTypes, op.getInits()); - - Block &newAfterBlock = *rewriter.createBlock( - &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs); - - // Build new results list and new after block args (unused entries will be - // null). - SmallVector<Value> newResults(op.getNumResults()); - SmallVector<Value> newAfterBlockArgs(op.getNumResults()); - for (const auto &it : llvm::enumerate(newResultsIndices)) { - newResults[it.value()] = newWhile.getResult(it.index()); - newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index()); - } - - rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(), - newWhile.getBefore().begin()); - - Block &afterBlock = *op.getAfterBody(); - rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs); - - rewriter.replaceOp(op, newResults); - return success(); - } -}; - /// Replace operations equivalent to the condition in the do block with true, /// since otherwise the block would not be evaluated. 
/// @@ -4264,127 +3622,6 @@ struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> { } }; -/// Remove unused init/yield args. -struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> { - using OpRewritePattern<WhileOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(WhileOp op, - PatternRewriter &rewriter) const override { - - if (!llvm::any_of(op.getBeforeArguments(), - [](Value arg) { return arg.use_empty(); })) - return rewriter.notifyMatchFailure(op, "No args to remove"); - - YieldOp yield = op.getYieldOp(); - - // Collect results mapping, new terminator args and new result types. - SmallVector<Value> newYields; - SmallVector<Value> newInits; - llvm::BitVector argsToErase; - - size_t argsCount = op.getBeforeArguments().size(); - newYields.reserve(argsCount); - newInits.reserve(argsCount); - argsToErase.reserve(argsCount); - for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip( - op.getBeforeArguments(), yield.getOperands(), op.getInits())) { - if (beforeArg.use_empty()) { - argsToErase.push_back(true); - } else { - argsToErase.push_back(false); - newYields.emplace_back(yieldValue); - newInits.emplace_back(initValue); - } - } - - Block &beforeBlock = *op.getBeforeBody(); - Block &afterBlock = *op.getAfterBody(); - - beforeBlock.eraseArguments(argsToErase); - - Location loc = op.getLoc(); - auto newWhileOp = - WhileOp::create(rewriter, loc, op.getResultTypes(), newInits, - /*beforeBody*/ nullptr, /*afterBody*/ nullptr); - Block &newBeforeBlock = *newWhileOp.getBeforeBody(); - Block &newAfterBlock = *newWhileOp.getAfterBody(); - - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(yield); - rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields); - - rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, - newBeforeBlock.getArguments()); - rewriter.mergeBlocks(&afterBlock, &newAfterBlock, - newAfterBlock.getArguments()); - - rewriter.replaceOp(op, newWhileOp.getResults()); - return success(); - } -}; - -/// Remove duplicated 
ConditionOp args. -struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(WhileOp op, - PatternRewriter &rewriter) const override { - ConditionOp condOp = op.getConditionOp(); - ValueRange condOpArgs = condOp.getArgs(); - - llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs); - - if (argsSet.size() == condOpArgs.size()) - return rewriter.notifyMatchFailure(op, "No results to remove"); - - llvm::SmallDenseMap<Value, unsigned> argsMap; - SmallVector<Value> newArgs; - argsMap.reserve(condOpArgs.size()); - newArgs.reserve(condOpArgs.size()); - for (Value arg : condOpArgs) { - if (!argsMap.count(arg)) { - auto pos = static_cast<unsigned>(argsMap.size()); - argsMap.insert({arg, pos}); - newArgs.emplace_back(arg); - } - } - - ValueRange argsRange(newArgs); - - Location loc = op.getLoc(); - auto newWhileOp = - scf::WhileOp::create(rewriter, loc, argsRange.getTypes(), op.getInits(), - /*beforeBody*/ nullptr, - /*afterBody*/ nullptr); - Block &newBeforeBlock = *newWhileOp.getBeforeBody(); - Block &newAfterBlock = *newWhileOp.getAfterBody(); - - SmallVector<Value> afterArgsMapping; - SmallVector<Value> resultsMapping; - for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) { - auto it = argsMap.find(arg); - assert(it != argsMap.end()); - auto pos = it->second; - afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos)); - resultsMapping.emplace_back(newWhileOp->getResult(pos)); - } - - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(condOp); - rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(), - argsRange); - - Block &beforeBlock = *op.getBeforeBody(); - Block &afterBlock = *op.getAfterBody(); - - rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, - newBeforeBlock.getArguments()); - rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping); - rewriter.replaceOp(op, resultsMapping); - return success(); - } -}; - /// 
If both ranges contain same values return mappping indices from args2 to /// args1. Otherwise return std::nullopt. static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1, @@ -4475,11 +3712,10 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> { void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<RemoveLoopInvariantArgsFromBeforeBlock, - RemoveLoopInvariantValueYielded, WhileConditionTruth, - WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults, - WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>( - context); + results.add<WhileConditionTruth, WhileCmpCond, WhileOpAlignBeforeArgs, + WhileMoveIfDown>(context); + populateRegionBranchOpInterfaceCanonicalizationPatterns( + results, WhileOp::getOperationName()); } //===----------------------------------------------------------------------===// @@ -4654,43 +3890,11 @@ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> { } }; -/// Canonicalization patterns that folds away dead results of -/// "scf.index_switch" ops. -struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> { - using OpRewritePattern<IndexSwitchOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(IndexSwitchOp op, - PatternRewriter &rewriter) const override { - // Find dead results. - BitVector deadResults(op.getNumResults(), false); - for (auto [idx, result] : llvm::enumerate(op.getResults())) - if (result.use_empty()) - deadResults[idx] = true; - if (!deadResults.any()) - return rewriter.notifyMatchFailure(op, "no dead results to fold"); - - // Erase dead results. - auto newOp = - cast<scf::IndexSwitchOp>(rewriter.eraseOpResults(op, deadResults)); - - // Erase operands from yield ops. 
- auto updateCaseRegion = [&](Region ®ion) { - Operation *terminator = region.front().getTerminator(); - assert(isa<YieldOp>(terminator) && "expected yield op"); - rewriter.modifyOpInPlace( - terminator, [&]() { terminator->eraseOperands(deadResults); }); - }; - updateCaseRegion(newOp.getDefaultRegion()); - for (Region &caseRegion : newOp.getCaseRegions()) - updateCaseRegion(caseRegion); - - return success(); - } -}; - void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context); + results.add<FoldConstantCase>(context); + populateRegionBranchOpInterfaceCanonicalizationPatterns( + results, IndexSwitchOp::getOperationName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp index d393ddb8d8336..87a6aafcda002 100644 --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -10,7 +10,9 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "llvm/ADT/EquivalenceClasses.h" #include "llvm/Support/DebugLog.h" using namespace mlir; @@ -521,6 +523,23 @@ void RegionBranchOpInterface::getSuccessorOperandInputMapping( } } +static RegionBranchInverseSuccessorMapping invertRegionBranchSuccessorMapping( + const RegionBranchSuccessorMapping &operandToInputs) { + RegionBranchInverseSuccessorMapping inputToOperands; + for (const auto &[operand, inputs] : operandToInputs) { + for (Value input : inputs) + inputToOperands[input].push_back(operand); + } + return inputToOperands; +} + +void RegionBranchOpInterface::getSuccessorInputOperandMapping( + RegionBranchInverseSuccessorMapping &mapping) { + RegionBranchSuccessorMapping operandToInputs; + 
getSuccessorOperandInputMapping(operandToInputs); + mapping = invertRegionBranchSuccessorMapping(operandToInputs); +} + SmallVector<RegionBranchPoint> RegionBranchOpInterface::getAllRegionBranchPoints() { SmallVector<RegionBranchPoint> branchPoints; @@ -583,3 +602,467 @@ Region *mlir::getEnclosingRepetitiveRegion(Value value) { LDBG() << "No enclosing repetitive region found for value"; return nullptr; } + +/// Return "true" if `a` can be used in lieu of `b`, where `b` is a region +/// successor input and `a` is a "possible value" of `b`. Possible values are +/// successor operand values that are (maybe transitively) forwarded to `b`. +static bool isDefinedBefore(Operation *regionBranchOp, Value a, Value b) { + assert((b.getDefiningOp() == regionBranchOp || + b.getParentRegion()->getParentOp() == regionBranchOp) && + "b must be a region successor input"); + + // Case 1: `a` is defined inside of the region branch op. `a` must be + // directly nested in the region branch op. Otherwise, it could not have + // been among the possible values for a region successor input. + if (a.getParentRegion()->getParentOp() == regionBranchOp) { + // Case 1.1: If `b` is a result of the region branch op, `a` is not in + // scope for `b`. + // Example: + // %b = region_op({ + // ^bb0(%a1: ...): + // %a2 = ... + // }) + if (isa<OpResult>(b)) + return false; + + // Case 1.2: `b` is an entry block argument of a region. `a` is in scope + // for `b` only if it is also an entry block argument of the same region. + // Example: + // region_op({ + // ^bb0(%b: ..., %a: ...): + // ... + // }) + assert(isa<BlockArgument>(b) && "b must be a block argument"); + return isa<BlockArgument>(a) && cast<BlockArgument>(a).getOwner() == + cast<BlockArgument>(b).getOwner(); + } + + // Case 2: `a` is defined outside of the region branch op. In that case, we + // can safely assume that `a` was defined before `b`. Otherwise, it could not + // be among the possible values for a region successor input. 
+ // Example: + // { <- %a1 parent region begins here. + // ^bb0(%a1: ...): + // %a2 = ... + // %b1 = region_op({ + // ^bb1(%b2: ...): + // ... + // }) + // } + return true; +} + +/// Compute all non-successor input values that a successor input could have +/// based on the given successor input to successor operand mapping. +/// +/// Example 1: +/// %r = scf.for ... iter_args(%arg0 = %0) -> ... { +/// scf.yield %arg0 : ... +/// } +/// computePossibleValuesOfSuccessorInput(%arg0) = {%0} +/// computePossibleValuesOfSuccessorInput(%r) = {%0} +/// +/// Example 2: +/// %r = scf.for ... iter_args(%arg0 = %0) -> ... { +/// ... +/// scf.yield %1 : ... +/// } +/// computePossibleValuesOfSuccessorInput(%arg0) = {%0, %1} +/// computePossibleValuesOfSuccessorInput(%r) = {%0, %1} +static llvm::SmallDenseSet<Value> computePossibleValuesOfSuccessorInput( + Value value, const RegionBranchInverseSuccessorMapping &inputToOperands) { + llvm::SmallDenseSet<Value> possibleValues; + llvm::SmallDenseSet<Value> visited; + SmallVector<Value> worklist; + + // Starting with the given value, trace back all predecessor values (i.e., + // preceding successor operands) and add them to the set of possible values. + // If the successor operand is again a successor input, do not add it to + // the result set, but instead continue the traversal. + worklist.push_back(value); + while (!worklist.empty()) { + Value next = worklist.pop_back_val(); + auto it = inputToOperands.find(next); + if (it == inputToOperands.end()) { + possibleValues.insert(next); + continue; + } + for (OpOperand *operand : it->second) + if (visited.insert(operand->get()).second) + worklist.push_back(operand->get()); + } + + return possibleValues; +} + +namespace { +/// Try to make successor inputs dead by replacing their uses with values that +/// are not successor inputs. This pattern enables additional canonicalization +/// opportunities for RemoveDeadRegionBranchOpSuccessorInputs. +/// +/// Example: +/// +/// %r = scf.for ...
iter_args(%arg0 = %0) -> ... { +/// scf.yield %arg0 : ... +/// } +/// use(%r) +/// +/// According to `computePossibleValuesOfSuccessorInput`, the only possible +/// non-successor input value of %r and %arg0 is %0. Therefore, their uses can +/// be replaced with %0, resulting in the following IR: +/// +/// %r = scf.for ... iter_args(%arg0 = %0) -> ... { +/// scf.yield %0 : ... +/// } +/// use(%0) +/// +/// The IR can now be further canonicalized by +/// RemoveDeadRegionBranchOpSuccessorInputs. +struct MakeRegionBranchOpSuccessorInputsDead : public RewritePattern { + MakeRegionBranchOpSuccessorInputsDead(MLIRContext *context, StringRef name, + PatternBenefit benefit = 1) + : RewritePattern(name, benefit, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() && + "isolated-from-above ops are not supported"); + + // Compute the mapping of successor inputs to successor operands. + auto regionBranchOp = cast<RegionBranchOpInterface>(op); + RegionBranchInverseSuccessorMapping inputToOperands; + regionBranchOp.getSuccessorInputOperandMapping(inputToOperands); + + // Try to replace the uses of each successor input one-by-one. + bool changed = false; + for (Value value : inputToOperands.keys()) { + // Nothing to do for successor inputs that are already dead. + if (value.use_empty()) + continue; + // Nothing to do for successor inputs that may have multiple possible + // values. + llvm::SmallDenseSet<Value> possibleValues = + computePossibleValuesOfSuccessorInput(value, inputToOperands); + if (possibleValues.size() != 1) + continue; + assert(*possibleValues.begin() != value && + "successor inputs are supposed to be excluded"); + // Do not replace `value` with the found possible value if doing so would + // violate dominance. Example: + // %r = scf.execute_region ... { + // %a = ... + // scf.yield %a : ... 
+ // } + // In the above example, %a is the only possible value of %r, but it + // cannot be used as a replacement for %r. + if (!isDefinedBefore(regionBranchOp, *possibleValues.begin(), value)) + continue; + rewriter.replaceAllUsesWith(value, *possibleValues.begin()); + changed = true; + } + return success(changed); + } +}; + +/// Lookup a bit vector in the given mapping (DenseMap). If the key was not +/// found, create a new bit vector with the given size and initialize it with +/// false. +template <typename MappingTy, typename KeyTy> +static BitVector &lookupOrCreateBitVector(MappingTy &mapping, KeyTy key, + unsigned size) { + return mapping.try_emplace(key, size, false).first->second; +} + +/// Compute tied successor inputs. Tied successor inputs are successor inputs +/// that come as a set. If you erase one value from a set, you must erase all +/// values from the set. Otherwise, the op would become structurally invalid. +/// Each successor input appears in exactly one set. +/// +/// Example: +/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... { +/// ... +/// } +/// There are two sets: {{%r0, %arg0}, {%r1, %arg1}}. +static llvm::EquivalenceClasses<Value> computeTiedSuccessorInputs( + const RegionBranchSuccessorMapping &operandToInputs) { + llvm::EquivalenceClasses<Value> tiedSuccessorInputs; + for (const auto &[operand, inputs] : operandToInputs) { + assert(!inputs.empty() && "expected non-empty inputs"); + Value firstInput = inputs.front(); + tiedSuccessorInputs.insert(firstInput); + for (Value nextInput : llvm::drop_begin(inputs)) { + // As we explore more successor operand to successor input mappings, + // existing sets may get merged. + tiedSuccessorInputs.unionSets(firstInput, nextInput); + } + } + return tiedSuccessorInputs; +} + +/// Remove dead successor inputs from region branch ops. A successor input is +/// dead if it has no uses. 
Successor inputs come in sets of tied values: if +/// you remove one value from a set, you must remove all values from the set. +/// Furthermore, successor operands must also be removed. (Op operands are not +/// part of the set, but the set is built based on the successor operand to +/// successor input mapping.) +/// +/// Example 1: +/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... { +/// scf.yield %0, %arg1 : ... +/// } +/// use(%0, %1) +/// +/// There are two sets: {{%r0, %arg0}, {%r1, %arg1}}. All values in the first +/// set are dead, so %arg0 and %r0 can be removed. The second set must be kept +/// because %arg1 is still used by the scf.yield. The resulting IR is as +/// follows: +/// +/// %r1 = scf.for ... iter_args(%arg1 = %1) -> ... { +/// scf.yield %arg1 : ... +/// } +/// use(%0, %1) +/// +/// Example 2: +/// %r0, %r1 = scf.while (%arg0 = %0) { +/// scf.condition(...) %arg0, %arg0 : ... +/// } do { +/// ^bb0(%arg1: ..., %arg2: ...): +/// scf.yield %arg1 : ... +/// } +/// There are three sets: {{%r0, %arg1}, {%r1, %arg2}, {%arg0}}. +/// +/// Example 3: +/// %r1, %r2 = scf.if ... { +/// scf.yield %0, %1 : ... +/// } else { +/// scf.yield %2, %3 : ... +/// } +/// There are two sets: {{%r1}, {%r2}}. Each set has one value, so each value +/// can be removed independently of the other values. +struct RemoveDeadRegionBranchOpSuccessorInputs : public RewritePattern { + RemoveDeadRegionBranchOpSuccessorInputs(MLIRContext *context, StringRef name, + PatternBenefit benefit = 1) + : RewritePattern(name, benefit, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() && + "isolated-from-above ops are not supported"); + + // Compute tied values: values that must come as a set. If you remove one, + // you must remove all. If a successor op operand is forwarded to two + // successor inputs %a and %b, both %a and %b are in the same set.
+ auto regionBranchOp = cast<RegionBranchOpInterface>(op); + RegionBranchSuccessorMapping operandToInputs; + regionBranchOp.getSuccessorOperandInputMapping(operandToInputs); + llvm::EquivalenceClasses<Value> tiedSuccessorInputs = + computeTiedSuccessorInputs(operandToInputs); + + // Determine which values to remove and group them by block and operation. + SmallVector<Value> valuesToRemove; + DenseMap<Block *, BitVector> blockArgsToRemove; + DenseMap<Operation *, BitVector> resultsToRemove; + // Iterate over all sets of tied successor inputs. + for (auto it = tiedSuccessorInputs.begin(), e = tiedSuccessorInputs.end(); + it != e; ++it) { + if (!(*it)->isLeader()) + continue; + + // Value can be removed if it is dead and all other tied values are also + // dead. + bool allDead = true; + for (auto memberIt = tiedSuccessorInputs.member_begin(**it); + memberIt != tiedSuccessorInputs.member_end(); ++memberIt) { + // Iterate over all values in the set and check their liveness. + if (!memberIt->use_empty()) { + allDead = false; + break; + } + } + if (!allDead) + continue; + + // The entire set is dead. Group values by block and operation to + // simplify removal. + for (auto memberIt = tiedSuccessorInputs.member_begin(**it); + memberIt != tiedSuccessorInputs.member_end(); ++memberIt) { + if (auto arg = dyn_cast<BlockArgument>(*memberIt)) { + // Set blockArgsToRemove[block][arg_number] = true. + BitVector &vector = + lookupOrCreateBitVector(blockArgsToRemove, arg.getOwner(), + arg.getOwner()->getNumArguments()); + vector.set(arg.getArgNumber()); + } else { + // Set resultsToRemove[op][result_number] = true. 
+ OpResult result = cast<OpResult>(*memberIt); + BitVector &vector = + lookupOrCreateBitVector(resultsToRemove, result.getDefiningOp(), + result.getDefiningOp()->getNumResults()); + vector.set(result.getResultNumber()); + } + valuesToRemove.push_back(*memberIt); + } + } + + if (valuesToRemove.empty()) + return rewriter.notifyMatchFailure(op, "no values to remove"); + + // Find operands that must be removed together with the values. + RegionBranchInverseSuccessorMapping inputsToOperands = + invertRegionBranchSuccessorMapping(operandToInputs); + DenseMap<Operation *, llvm::BitVector> operandsToRemove; + for (Value value : valuesToRemove) { + for (OpOperand *operand : inputsToOperands[value]) { + // Set operandsToRemove[op][operand_number] = true. + BitVector &vector = + lookupOrCreateBitVector(operandsToRemove, operand->getOwner(), + operand->getOwner()->getNumOperands()); + vector.set(operand->getOperandNumber()); + } + } + + // Erase operands. + for (auto &pair : operandsToRemove) { + Operation *op = pair.first; + BitVector &operands = pair.second; + rewriter.modifyOpInPlace(op, [&]() { op->eraseOperands(operands); }); + } + + // Erase block arguments. + for (auto &pair : blockArgsToRemove) { + Block *block = pair.first; + BitVector &blockArg = pair.second; + rewriter.modifyOpInPlace(block->getParentOp(), + [&]() { block->eraseArguments(blockArg); }); + } + + // Erase op results. + for (auto [op, resultsToErase] : resultsToRemove) + rewriter.eraseOpResults(op, resultsToErase); + + return success(); + } +}; + +/// Return "true" if the two values are owned by the same operation or block. 
+static bool haveSameOwner(Value a, Value b) { + void *aOwner, *bOwner; + if (auto arg = dyn_cast<BlockArgument>(a)) + aOwner = arg.getOwner(); + else + aOwner = a.getDefiningOp(); + if (auto arg = dyn_cast<BlockArgument>(b)) + bOwner = arg.getOwner(); + else + bOwner = b.getDefiningOp(); + return aOwner == bOwner; +} + +/// Get the block argument or op result number of the given value. Argument +/// numbers and result numbers index different lists, so comparing the returned +/// numbers is only meaningful for values that have the same owner. +static unsigned getArgOrResultNumber(Value value) { + if (auto opResult = llvm::dyn_cast<OpResult>(value)) + return opResult.getResultNumber(); + return llvm::cast<BlockArgument>(value).getArgNumber(); +} + +/// Find duplicate successor inputs and make all dead except for one. Two +/// successor inputs are "duplicate" if their corresponding successor operands +/// have the same values. This pattern enables additional canonicalization +/// opportunities for RemoveDeadRegionBranchOpSuccessorInputs. +/// +/// Example: +/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %0) -> ... { +/// use(%arg0, %arg1) +/// ... +/// scf.yield %x, %x : ... +/// } +/// use(%r0, %r1) +/// +/// Operands of successor input %r0: [%0, %x] +/// Operands of successor input %r1: [%0, %x] ==> DUPLICATE! +/// Replace %r1 with %r0. +/// +/// Operands of successor input %arg0: [%0, %x] +/// Operands of successor input %arg1: [%0, %x] ==> DUPLICATE! +/// Replace %arg1 with %arg0. (We have to make sure that we make the same +/// decision as for the other tied successor inputs above. Otherwise, a set of +/// tied successor inputs may not become entirely dead.) +/// +/// The resulting IR is as follows: +/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %0) -> ... { +/// use(%arg0, %arg0) +/// ... +/// scf.yield %x, %x : ... +/// } +/// use(%r0, %r0) // Note: use(%r1, %r1) would also be correct, but it would +/// // leave both tied sets {%r0, %arg0} and {%r1, %arg1} +/// // partially alive, so neither set could be erased by +/// // RemoveDeadRegionBranchOpSuccessorInputs.
+struct RemoveDuplicateSuccessorInputUses : public RewritePattern { + RemoveDuplicateSuccessorInputUses(MLIRContext *context, StringRef name, + PatternBenefit benefit = 1) + : RewritePattern(name, benefit, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() && + "isolated-from-above ops are not supported"); + + // Collect all successor inputs and sort them. When dropping the uses of a + // successor input, we'd like to also drop the uses of the same tied + // successor inputs. Otherwise, a set of tied successor inputs may not + // become entirely dead, which is required for + // RemoveDeadRegionBranchOpSuccessorInputs to be able to erase them. + // (Sorting is not required for correctness.) + auto regionBranchOp = cast<RegionBranchOpInterface>(op); + RegionBranchInverseSuccessorMapping inputsToOperands; + regionBranchOp.getSuccessorInputOperandMapping(inputsToOperands); + SmallVector<Value> inputs = llvm::to_vector(inputsToOperands.keys()); + llvm::sort(inputs, [](Value a, Value b) { + return getArgOrResultNumber(a) < getArgOrResultNumber(b); + }); + + // Check every distinct pair of successor inputs for duplicates. Replace + // `input2` with `input1` if they are duplicates. + bool changed = false; + unsigned numInputs = inputs.size(); + for (auto i : llvm::seq<unsigned>(0, numInputs)) { + Value input1 = inputs[i]; + for (auto j : llvm::seq<unsigned>(i + 1, numInputs)) { + Value input2 = inputs[j]; + // Nothing to do if input2 is already dead. + if (input2.use_empty()) + continue; + // Replace only values that belong to the same block / operation. + // This implies that the two values are either both block arguments or + // both op results. + if (!haveSameOwner(input1, input2)) + continue; + + // Gather the predecessor value for each predecessor (region branch + // point). The two inputs are duplicates if each predecessor forwards + // the same value. 
+ DenseMap<Operation *, Value> operands1, operands2; + for (OpOperand *operand : inputsToOperands[input1]) { + assert(!operands1.contains(operand->getOwner())); + operands1[operand->getOwner()] = operand->get(); + } + for (OpOperand *operand : inputsToOperands[input2]) { + assert(!operands2.contains(operand->getOwner())); + operands2[operand->getOwner()] = operand->get(); + } + if (operands1 == operands2) { + rewriter.replaceAllUsesWith(input2, input1); + changed = true; + } + } + } + return success(changed); + } +}; +} // namespace + +void mlir::populateRegionBranchOpInterfaceCanonicalizationPatterns( + RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit) { + patterns.add<MakeRegionBranchOpSuccessorInputsDead, + RemoveDuplicateSuccessorInputUses, + RemoveDeadRegionBranchOpSuccessorInputs>(patterns.getContext(), + opName, benefit); +} diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index d5d0aee3bbe25..11dc4f04af32e 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -1071,17 +1071,17 @@ func.func @invariant_loop_args_in_same_order(%f_arg0: tensor<i32>) -> (tensor<i3 // CHECK: %[[ZERO:.*]] = arith.constant dense<0> // CHECK: %[[ONE:.*]] = arith.constant dense<1> // CHECK: %[[CST42:.*]] = arith.constant dense<42> -// CHECK: %[[WHILE:.*]]:3 = scf.while (%[[ARG0:.*]] = %[[ZERO]], %[[ARG2:.*]] = %[[ONE]], %[[ARG3:.*]] = %[[ONE]]) +// CHECK: %[[WHILE:.*]]:2 = scf.while (%[[ARG0:.*]] = %[[ZERO]], %[[ARG2:.*]] = %[[ONE]]) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) // CHECK: arith.cmpi slt, %[[ARG0]], %{{.*}} // CHECK: tensor.extract %{{.*}}[] -// CHECK: scf.condition(%{{.*}}) %[[ARG0]], %[[ARG2]], %[[ARG3]] +// CHECK: scf.condition(%{{.*}}) %[[ARG0]], %[[ARG2]] // CHECK: } do { -// CHECK: ^{{.*}}(%[[ARG0:.*]]: tensor<i32>, %[[ARG2:.*]]: tensor<i32>, %[[ARG3:.*]]: tensor<i32>): +// CHECK: ^{{.*}}(%[[ARG0:.*]]: tensor<i32>, %[[ARG2:.*]]: 
tensor<i32>): // CHECK: %[[VAL0:.*]] = arith.addi %[[ARG0]], %[[FUNC_ARG0]] -// CHECK: %[[VAL1:.*]] = arith.addi %[[ARG2]], %[[ARG3]] -// CHECK: scf.yield %[[VAL0]], %[[VAL1]], %[[VAL1]] +// CHECK: %[[VAL1:.*]] = arith.addi %[[ARG2]], %[[ARG2]] +// CHECK: scf.yield %[[VAL0]], %[[VAL1]] // CHECK: } -// CHECK: return %[[WHILE]]#0, %[[FUNC_ARG0]], %[[WHILE]]#1, %[[WHILE]]#2, %[[ZERO]] +// CHECK: return %[[WHILE]]#0, %[[FUNC_ARG0]], %[[WHILE]]#1, %[[WHILE]]#1, %[[ZERO]] // CHECK-LABEL: @while_loop_invariant_argument_different_order func.func @while_loop_invariant_argument_different_order(%arg : tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) { @@ -1736,11 +1736,11 @@ module { // Test case with multiple scf.yield ops with at least one different operand, then no change. -// CHECK: %[[VAL_3:.*]]:2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline { +// CHECK: %[[VAL_3:.*]] = scf.execute_region -> memref<1x120xui8> no_inline { // CHECK: ^bb1: -// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8> +// CHECK: scf.yield %{{.*}} : memref<1x120xui8> // CHECK: ^bb2: -// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8> +// CHECK: scf.yield %{{.*}} : memref<1x120xui8> // CHECK: } module { @@ -2178,16 +2178,14 @@ func.func @scf_for_all_step_size_0() { // CHECK-SAME: %[[arg0:.*]]: index // CHECK-DAG: %[[c10:.*]] = arith.constant 10 // CHECK-DAG: %[[c11:.*]] = arith.constant 11 -// CHECK: %[[switch:.*]] = scf.index_switch %[[arg0]] -> index +// CHECK: scf.index_switch %[[arg0]] // CHECK: case 1 { // CHECK: memref.store %[[c10]] -// CHECK: scf.yield %[[arg0]] : index // CHECK: } // CHECK: default { // CHECK: memref.store %[[c11]] -// CHECK: scf.yield %[[arg0]] : index // CHECK: } -// CHECK: return %[[switch]] +// CHECK: return %[[arg0]] func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> index { %non_live, %live = scf.index_switch %arg0 -> i32, index 
case 1 { diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index 2584573c8b4dc..ae83eac0c376f 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -250,13 +250,13 @@ func.func @main() -> (i32, i32) { // CHECK-NEXT: } // CHECK-CANONICALIZE: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 { -// CHECK-CANONICALIZE-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) { +// CHECK-CANONICALIZE: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) { // CHECK-CANONICALIZE-NEXT: %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]] -// CHECK-CANONICALIZE-NEXT: scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32 +// CHECK-CANONICALIZE: scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32 // CHECK-CANONICALIZE-NEXT: } do { // CHECK-CANONICALIZE-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32): // CHECK-CANONICALIZE-NEXT: %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]] -// CHECK-CANONICALIZE-NEXT: scf.yield %[[live_1]] : i32 +// CHECK-CANONICALIZE: scf.yield %[[live_1]] : i32 // CHECK-CANONICALIZE-NEXT: } // CHECK-CANONICALIZE-NEXT: return %[[live_and_non_live]]#0 // CHECK-CANONICALIZE-NEXT: } @@ -306,7 +306,7 @@ func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_o // CHECK-CANONICALIZE: func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 { // CHECK-CANONICALIZE-NEXT: %[[c0:.*]] = arith.constant 0 // CHECK-CANONICALIZE-NEXT: %[[c1:.*]] = arith.constant 1 -// CHECK-CANONICALIZE-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) { +// CHECK-CANONICALIZE: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] 
= %[[c1]]) : (i32, i32) -> (i32, i32) { // CHECK-CANONICALIZE-NEXT: func.call @identity() : () -> () // CHECK-CANONICALIZE-NEXT: scf.condition(%[[arg2]]) %[[arg3]], %[[arg4]] : i32, i32 // CHECK-CANONICALIZE-NEXT: } do { _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
