llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) <details> <summary>Changes</summary> Add a new canonicalization pattern that inlines the body of acyclic `RegionBranchOpInterface` ops. This pattern is a generalization and replacement for the following existing patterns: * `SingleBlockExecuteInliner`: inlines `scf.execute_region` ops with a single block. * `SimplifyTrivialLoops`: inlines / folds away `scf.for` ops with 0 or 1 iterations. * `RemoveStaticCondition`: inlines `scf.if` ops with a static condition. * `FoldConstantCase`: inlines `scf.index_switch` ops with a constant operand. Additionally, this new pattern is also enabled for `scf.while` ops. Loops with `scf.condition(%false)` are now also inlined. (New test case added.) The new pattern looks for region branch ops with a single acyclic path through the operation (starting from and ending at "parent"). All regions on that path can be inlined into the enclosing block. Depends on #<!-- -->177116. 
--- Patch is 27.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/176641.diff 6 Files Affected: - (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.h (+39) - (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+26-145) - (modified) mlir/lib/Interfaces/ControlFlowInterfaces.cpp (+233) - (modified) mlir/test/Dialect/Arith/int-range-interface.mlir (+4-2) - (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+20) - (modified) mlir/test/Dialect/SCF/one-shot-bufferize.mlir (+4-2) ``````````diff diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h index d764089f5ccc8..a76dce6f2ffc5 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h @@ -320,6 +320,45 @@ Region *getEnclosingRepetitiveRegion(Value value); void populateRegionBranchOpInterfaceCanonicalizationPatterns( RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit = 1); +/// Helper function for the region branch op inlining pattern that builds +/// replacement values for non-successor-input values. +using NonSuccessorInputReplacementBuilderFn = + std::function<Value(OpBuilder &, Location, Value)>; +/// Helper function for the region branch op inlining pattern that checks if the +/// pattern is applicable to the given operation. +using PatternMatcherFn = std::function<LogicalResult(Operation *)>; + +namespace detail { +/// Default implementation of the non-successor-input replacement builder +/// function. This default implemention assumes that all block arguments and +/// op results are successor inputs. +static inline Value defaultReplBuilderFn(OpBuilder &builder, Location loc, + Value value) { + llvm_unreachable("defaultReplBuilderFn not implemented"); +} + +/// Default implementation of the pattern matcher function. 
+static inline LogicalResult defaultMatcherFn(Operation *op) { + return success(); +} +} // namespace detail + +/// Populate a pattern that inlines the body of region branch ops when there is +/// a single acyclic path through the region branch op, starting from "parent" +/// and ending at "parent". For details, refer to the documentation of the +/// pattern. +/// +/// `replBuilderFn` is a function that builds replacement values for +/// non-successor-input values of the region branch op. `matcherFn` is a +/// function that checks if the pattern is applicable to the given operation. +/// Both functions are optional. +void populateRegionBranchOpInterfaceInliningPattern( + RewritePatternSet &patterns, StringRef opName, + NonSuccessorInputReplacementBuilderFn replBuilderFn = + detail::defaultReplBuilderFn, + PatternMatcherFn matcherFn = detail::defaultMatcherFn, + PatternBenefit benefit = 1); + //===----------------------------------------------------------------------===// // ControlFlow Traits //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 86e66dbaf6171..2ebece4bdedb7 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -132,19 +132,6 @@ std::optional<llvm::APSInt> mlir::scf::computeUbMinusLb(Value lb, Value ub, // ExecuteRegionOp //===----------------------------------------------------------------------===// -/// Replaces the given op with the contents of the given single-block region, -/// using the operands of the block terminator to replace operation results. 
-static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, - Region &region, ValueRange blockArgs = {}) { - assert(region.hasOneBlock() && "expected single-block region"); - Block *block = &region.front(); - Operation *terminator = block->getTerminator(); - ValueRange results = terminator->getOperands(); - rewriter.inlineBlockBefore(block, op, blockArgs); - rewriter.replaceOp(op, results); - rewriter.eraseOp(terminator); -} - /// /// (ssa-id `=`)? `execute_region` `->` function-result-type `{` /// block+ @@ -192,32 +179,6 @@ LogicalResult ExecuteRegionOp::verify() { return success(); } -// Inline an ExecuteRegionOp if it only contains one block. -// "test.foo"() : () -> () -// %v = scf.execute_region -> i64 { -// %x = "test.val"() : () -> i64 -// scf.yield %x : i64 -// } -// "test.bar"(%v) : (i64) -> () -// -// becomes -// -// "test.foo"() : () -> () -// %x = "test.val"() : () -> i64 -// "test.bar"(%x) : (i64) -> () -// -struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { - using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(ExecuteRegionOp op, - PatternRewriter &rewriter) const override { - if (!op.getRegion().hasOneBlock() || op.getNoInline()) - return failure(); - replaceOpWithRegion(rewriter, op, op.getRegion()); - return success(); - } -}; - // Inline an ExecuteRegionOp if its parent can contain multiple blocks. // TODO generalize the conditions for operations which can be inlined into. 
// func @func_execute_region_elim() { @@ -293,9 +254,15 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context); + results.add<MultiBlockExecuteInliner>(context); populateRegionBranchOpInterfaceCanonicalizationPatterns( results, ExecuteRegionOp::getOperationName()); + // Inline ops with a single block that are not marked as "no_inline". + populateRegionBranchOpInterfaceInliningPattern( + results, ExecuteRegionOp::getOperationName(), + mlir::detail::defaultReplBuilderFn, [](Operation *op) { + return failure(cast<ExecuteRegionOp>(op).getNoInline()); + }); } void ExecuteRegionOp::getSuccessorRegions( @@ -962,54 +929,6 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, } namespace { -/// Rewriting pattern that erases loops that are known not to iterate, replaces -/// single-iteration loops with their bodies, and removes empty loops that -/// iterate at least once and only return values defined outside of the loop. 
-struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> { - using OpRewritePattern<ForOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(ForOp op, - PatternRewriter &rewriter) const override { - std::optional<APInt> tripCount = op.getStaticTripCount(); - if (!tripCount.has_value()) - return rewriter.notifyMatchFailure(op, - "can't compute constant trip count"); - - if (tripCount->isZero()) { - LDBG() << "SimplifyTrivialLoops tripCount is 0 for loop " - << OpWithFlags(op, OpPrintingFlags().skipRegions()); - rewriter.replaceOp(op, op.getInitArgs()); - return success(); - } - - if (tripCount->getSExtValue() == 1) { - LDBG() << "SimplifyTrivialLoops tripCount is 1 for loop " - << OpWithFlags(op, OpPrintingFlags().skipRegions()); - SmallVector<Value, 4> blockArgs; - blockArgs.reserve(op.getInitArgs().size() + 1); - blockArgs.push_back(op.getLowerBound()); - llvm::append_range(blockArgs, op.getInitArgs()); - replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs); - return success(); - } - - // Now we are left with loops that have more than 1 iterations. - Block &block = op.getRegion().front(); - if (!llvm::hasSingleElement(block)) - return failure(); - // The loop is empty and iterates at least once, if it only returns values - // defined outside of the loop, remove it and replace it with yield values. 
- if (llvm::any_of(op.getYieldedValues(), - [&](Value v) { return !op.isDefinedOutsideOfLoop(v); })) - return failure(); - LDBG() << "SimplifyTrivialLoops empty body loop allows replacement with " - "yield operands for loop " - << OpWithFlags(op, OpPrintingFlags().skipRegions()); - rewriter.replaceOp(op, op.getYieldedValues()); - return success(); - } -}; - /// Fold scf.for iter_arg/result pairs that go through incoming/ougoing /// a tensor.cast op pair so as to pull the tensor.cast inside the scf.for: /// @@ -1072,9 +991,20 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> { void ForOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<SimplifyTrivialLoops, ForOpTensorCastFolder>(context); + results.add<ForOpTensorCastFolder>(context); populateRegionBranchOpInterfaceCanonicalizationPatterns( results, ForOp::getOperationName()); + populateRegionBranchOpInterfaceInliningPattern( + results, ForOp::getOperationName(), + /*replBuilderFn=*/[](OpBuilder &builder, Location loc, Value value) { + // scf.for has only one non-successor input value: the loop induction + // variable. In case of a single acyclic path through the op, the IV can + // be safely replaced with the lower bound. 
+ auto blockArg = cast<BlockArgument>(value); + assert(blockArg.getArgNumber() == 0 && "expected induction variable"); + auto forOp = cast<ForOp>(blockArg.getOwner()->getParentOp()); + return forOp.getLowerBound(); + }); } std::optional<APInt> ForOp::getConstantStep() { @@ -2218,26 +2148,6 @@ void IfOp::getRegionInvocationBounds( } namespace { -struct RemoveStaticCondition : public OpRewritePattern<IfOp> { - using OpRewritePattern<IfOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(IfOp op, - PatternRewriter &rewriter) const override { - BoolAttr condition; - if (!matchPattern(op.getCondition(), m_Constant(&condition))) - return failure(); - - if (condition.getValue()) - replaceOpWithRegion(rewriter, op, op.getThenRegion()); - else if (!op.getElseRegion().empty()) - replaceOpWithRegion(rewriter, op, op.getElseRegion()); - else - rewriter.eraseOp(op); - - return success(); - } -}; - /// Hoist any yielded results whose operands are defined outside /// the if, to a select instruction. 
struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> { @@ -2788,10 +2698,11 @@ void IfOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add<CombineIfs, CombineNestedIfs, ConditionPropagation, ConvertTrivialIfToSelect, RemoveEmptyElseBranch, - RemoveStaticCondition, ReplaceIfYieldWithConditionOrValue>( - context); + ReplaceIfYieldWithConditionOrValue>(context); populateRegionBranchOpInterfaceCanonicalizationPatterns( results, IfOp::getOperationName()); + populateRegionBranchOpInterfaceInliningPattern(results, + IfOp::getOperationName()); } Block *IfOp::thenBlock() { return &getThenRegion().back(); } @@ -3796,6 +3707,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results, WhileMoveIfDown>(context); populateRegionBranchOpInterfaceCanonicalizationPatterns( results, WhileOp::getOperationName()); + populateRegionBranchOpInterfaceInliningPattern(results, + WhileOp::getOperationName()); } //===----------------------------------------------------------------------===// @@ -3942,44 +3855,12 @@ void IndexSwitchOp::getRegionInvocationBounds( bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex); } -struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> { - using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(scf::IndexSwitchOp op, - PatternRewriter &rewriter) const override { - // If `op.getArg()` is a constant, select the region that matches with - // the constant value. Use the default region if no matche is found. - std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg()); - if (!maybeCst.has_value()) - return failure(); - int64_t cst = *maybeCst; - int64_t caseIdx, e = op.getNumCases(); - for (caseIdx = 0; caseIdx < e; ++caseIdx) { - if (cst == op.getCases()[caseIdx]) - break; - } - - Region &r = (caseIdx < op.getNumCases()) ? 
op.getCaseRegions()[caseIdx] - : op.getDefaultRegion(); - Block &source = r.front(); - Operation *terminator = source.getTerminator(); - SmallVector<Value> results = terminator->getOperands(); - - rewriter.inlineBlockBefore(&source, op); - rewriter.eraseOp(terminator); - // Replace the operation with a potentially empty list of results. - // Fold mechanism doesn't support the case where the result list is empty. - rewriter.replaceOp(op, results); - - return success(); - } -}; - void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<FoldConstantCase>(context); populateRegionBranchOpInterfaceCanonicalizationPatterns( results, IndexSwitchOp::getOperationName()); + populateRegionBranchOpInterfaceInliningPattern( + results, IndexSwitchOp::getOperationName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp index ebf78d8bd60ce..8ed32ddf39a53 100644 --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -9,6 +9,7 @@ #include <utility> #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" @@ -1025,6 +1026,230 @@ struct RemoveDuplicateSuccessorInputUses : public RewritePattern { return success(changed); } }; + +/// Given a range of values, return a vector of attributes of the same size, +/// where the i-th attribute is the constant value of the i-th value. If a +/// value is not constant, the corresponding attribute is null. 
+static SmallVector<Attribute> extractConstants(ValueRange values) { + return llvm::map_to_vector(values, [](Value value) { + Attribute attr; + matchPattern(value, m_Constant(&attr)); + return attr; + }); +} + +/// Return all successor regions when branching from the given region branch +/// point. This helper functions extracts all constant operand values and +/// passes them to the `RegionBranchOpInterface`. +static SmallVector<RegionSuccessor> +getSuccessorRegionsWithAttrs(RegionBranchOpInterface op, + RegionBranchPoint point) { + SmallVector<RegionSuccessor> successors; + if (point.isParent()) { + op.getEntrySuccessorRegions(extractConstants(op->getOperands()), + successors); + return successors; + } + RegionBranchTerminatorOpInterface terminator = + point.getTerminatorPredecessorOrNull(); + terminator.getSuccessorRegions(extractConstants(terminator->getOperands()), + successors); + return successors; +} + +/// Find the single acyclic path through the given region branch op. Return an +/// empty vector if no such path or multiple such paths exist. +/// +/// Example: "scf.if %true" has a single path: parent => then_region => parent +/// +/// Example: "scf.if ???" has multiple paths: +/// (1) parent => then_region => parent +/// (2) parent => else_region => parent +/// +/// Example: "scf.while with scf.condition(%false)" has a single path: +/// parent => before_region => parent +/// +/// Example: "scf.for with 0 iterations" has a single path: parent => parent +/// +/// Note: Each path starts and ends with "parent". The "parent" at the beginning +/// of the path is omitted from the result. +/// +/// Note: This function also returns an "empty" path when a region with multiple +/// blocks was found. +static SmallVector<RegionSuccessor> +computeSingleAcyclicRegionBranchPath(RegionBranchOpInterface op) { + llvm::SmallDenseSet<Region *> visited; + SmallVector<RegionSuccessor> path; + + // Path starts with "parent". 
+ RegionBranchPoint next = RegionBranchPoint::parent(); + do { + SmallVector<RegionSuccessor> successors = + getSuccessorRegionsWithAttrs(op, next); + if (successors.size() != 1) { + // There are multiple region successors. I.e., there are multiple paths + // through the region branch op. + return {}; + } + path.push_back(successors.front()); + if (successors.front().isParent()) { + // Found path that ends with "parent". + return path; + } + Region *region = successors.front().getSuccessor(); + if (!region->hasOneBlock()) { + // Entering a region with multiple blocks. Such regions are not supported + // at the moment. + return {}; + } + if (!visited.insert(region).second) { + // We have already visited this region. I.e., we have found a cycle. + return {}; + } + auto terminator = + dyn_cast<RegionBranchTerminatorOpInterface>(&region->front().back()); + if (!terminator) { + // Region has no RegionBranchTerminatorOpInterface terminator. E.g., the + // terminator could be a "ub.unreachable" op. Such IR is not supported. + return {}; + } + next = RegionBranchPoint(terminator); + } while (true); + llvm_unreachable("expected to return from loop"); +} + +/// Inline the body of the matched region branch op into the enclosing block if +/// there is exactly one acyclic path through the region branch op, starting +/// from "parent", and if that path ends with "parent". +/// +/// Example: This pattern can inline "scf.for" operations that are guaranteed to +/// have a single iteration, as indicated by the region branch path "parent => +/// region => parent". "scf.for" operations have a non-successor-input: the loop +/// induction variable. Non-successor-input values have op-specific semantics +/// and cannot be reasoned about through the `RegionBranchOpInterface`. 
A +/// replacement value for non-successor-inputs is injected by the user-specified +/// lambda: in the case of the loop induction variable of an "scf.for", the +/// lower bound of the loop is used as a replacement value. +/// +/// Before pattern application: +/// %r = scf.for %iv = %c5 to %c6 step %c1 iter_args(%arg0 = %0) { +/// %1 = "producer"(%arg0, %iv) +/// scf.yield %1 +/// } +/// "user"(%r) +/// +/// After pattern application: +/// %1 = "producer"(%0, %c5) +/// "user"(%1) +/// +/// This pattern is limited to the following cases: +/// - Only regions with a single block are supported. This could be generalized. +/// - Region branch ops with side effects are not supported. (Recursive side +/// effects are fine.) +/// +/// Note: This pattern queries the region dataflow from the +/// `RegionBranchOpInterface`. Replacement values are for block arguments / op +/// results are determined based on region dataflow. In case of +/// non-successor-inputs (whose values are not modeled by the +/// `RegionBranchOpInterface`), a user-specified lambda is queried. +struct InlineRegionBranchOp : public RewritePattern { + InlineRegionBranchOp(MLIRContext *context, StringRef name, + NonSuccessorInputReplacementBuilderFn replBuilderFn, + PatternMatcherFn matcherFn, PatternBenefit benefit = 1) + : RewritePattern(name, benefit, context), replBuilderFn(replBuilderFn), + matcherFn(matcherFn) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + // Check if the pattern is applicable to the given operation. + if (failed(matcherFn(op))) + return rewriter.notifyMatchFailure(op, "pattern not applicable"); + + // Patterns without recursive memory effects could have side effects, so + // it is not safe to fold such ops away. 
+ if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) + return rewriter.notifyMatchFailure( + op, "pattern not applicable to ops without recursive memory effects"); + + // Find the single acyclic path through the region branch op. + auto regionBranchOp = cast<RegionBranchOpInterface>(op); + SmallVector<RegionSuccessor> path = + computeSingleAcyclicRegionBranchPath(regionBranchOp); + if (path.empty()) + return rewriter.notifyMatchFailure( ... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/176641 _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
