llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Instead of op-specific cleanup patterns for region branch ops to remove dead 
values etc., add a set of patterns that can handle all 
`RegionBranchOpInterface` ops. These patterns are enabled only for selected SCF 
dialect ops at the moment. (It is not possible to register canonicalization 
patterns for op interfaces.)

This commit removes many similar canonicalization patterns from the SCF 
dialect. The newly added canonicalization patterns allow users to get the same 
canonicalizations for free for their own ops.

Implementation outline: This commit adds 3 canonicalization patterns.
* `MakeRegionBranchOpSuccessorInputsDead`: Remove uses of successor inputs, by 
swapping them for successor operand values.
* `RemoveDuplicateSuccessorInputUses`: Remove uses of successor inputs that are 
duplicates. (Similar to `WhileRemoveDuplicatedResults` in the SCF dialect.)
* `RemoveDeadRegionBranchOpSuccessorInputs`: Remove dead successor inputs if 
all of their "tied" successor inputs are also dead. (Similar to 
`WhileUnusedResult` in the SCF dialect.)

Depends on #<!-- -->173505.

---

Patch is 64.33 KiB, truncated to 20.00 KiB below, full version: 
https://github.com/llvm/llvm-project/pull/174094.diff


6 Files Affected:

- (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.h (+9) 
- (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.td (+5) 
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+17-813) 
- (modified) mlir/lib/Interfaces/ControlFlowInterfaces.cpp (+467) 
- (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+11-13) 
- (modified) mlir/test/Transforms/remove-dead-values.mlir (+4-4) 


``````````diff
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h 
b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 566f4b8fadb5d..ea85b2d1b5cb6 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -16,6 +16,7 @@
 
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "llvm/ADT/PointerUnion.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/raw_ostream.h"
@@ -188,6 +189,8 @@ LogicalResult verifyTypesAlongControlFlowEdges(Operation 
*op);
 /// possible successors.) Operands that not forwarded at all are not present in
 /// the mapping.
 using RegionBranchSuccessorMapping = DenseMap<OpOperand *, SmallVector<Value>>;
+using RegionBranchInverseSuccessorMapping =
+    DenseMap<Value, SmallVector<OpOperand *>>;
 
 /// This class represents a successor of a region. A region successor can 
either
 /// be another region, or the parent operation. If the successor is a region,
@@ -350,6 +353,12 @@ Region *getEnclosingRepetitiveRegion(Operation *op);
 /// exists.
 Region *getEnclosingRepetitiveRegion(Value value);
 
+/// Populate canonicalization patterns that simplify successor operands/inputs
+/// of region branch operations. Only operations with the given name are
+/// matched.
+void populateRegionBranchOpInterfaceCanonicalizationPatterns(
+    RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit = 1);
+
 
//===----------------------------------------------------------------------===//
 // ControlFlow Traits
 
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td 
b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 2e654ba04ffe5..70aed9e1e11c6 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -355,6 +355,11 @@ def RegionBranchOpInterface : 
OpInterface<"RegionBranchOpInterface"> {
         ::mlir::RegionBranchSuccessorMapping &mapping,
         std::optional<::mlir::RegionBranchPoint> src = std::nullopt);
 
+    /// Build a mapping from successor inputs to successor operands. This is
+    /// the same as "getSuccessorOperandInputMapping", but inverted.
+    void getSuccessorInputOperandMapping(
+        ::mlir::RegionBranchInverseSuccessorMapping &mapping);
+
     /// Return all possible region branch points: the region branch op itself
     /// and all region branch terminators.
     ::llvm::SmallVector<::mlir::RegionBranchPoint> getAllRegionBranchPoints();
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 8803a6d136f7a..95a854b655a53 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -291,102 +291,11 @@ struct MultiBlockExecuteInliner : public 
OpRewritePattern<ExecuteRegionOp> {
   }
 };
 
-// Pattern to eliminate ExecuteRegionOp results which forward external
-// values from the region. In case there are multiple yield operations,
-// all of them must have the same operands in order for the pattern to be
-// applicable.
-struct ExecuteRegionForwardingEliminator
-    : public OpRewritePattern<ExecuteRegionOp> {
-  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExecuteRegionOp op,
-                                PatternRewriter &rewriter) const override {
-    if (op.getNumResults() == 0)
-      return failure();
-
-    SmallVector<Operation *> yieldOps;
-    for (Block &block : op.getRegion()) {
-      if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator()))
-        yieldOps.push_back(yield.getOperation());
-    }
-
-    if (yieldOps.empty())
-      return failure();
-
-    // Check if all yield operations have the same operands.
-    auto yieldOpsOperands = yieldOps[0]->getOperands();
-    for (auto *yieldOp : yieldOps) {
-      if (yieldOp->getOperands() != yieldOpsOperands)
-        return failure();
-    }
-
-    SmallVector<Value> externalValues;
-    SmallVector<Value> internalValues;
-    SmallVector<Value> opResultsToReplaceWithExternalValues;
-    SmallVector<Value> opResultsToKeep;
-    for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) {
-      if (isValueFromInsideRegion(yieldedValue, op)) {
-        internalValues.push_back(yieldedValue);
-        opResultsToKeep.push_back(op.getResult(index));
-      } else {
-        externalValues.push_back(yieldedValue);
-        opResultsToReplaceWithExternalValues.push_back(op.getResult(index));
-      }
-    }
-    // No yielded external values - nothing to do.
-    if (externalValues.empty())
-      return failure();
-
-    // There are yielded external values - create a new execute_region 
returning
-    // just the internal values.
-    SmallVector<Type> resultTypes;
-    for (Value value : internalValues)
-      resultTypes.push_back(value.getType());
-    auto newOp =
-        ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes));
-    newOp->setAttrs(op->getAttrs());
-
-    // Move old op's region to the new operation.
-    rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
-                                newOp.getRegion().end());
-
-    // Replace all yield operations with a new yield operation with updated
-    // results. scf.execute_region must have at least one yield operation.
-    for (auto *yieldOp : yieldOps) {
-      rewriter.setInsertionPoint(yieldOp);
-      rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp,
-                                                ValueRange(internalValues));
-    }
-
-    // Replace the old operation with the external values directly.
-    rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues,
-                                externalValues);
-    // Replace the old operation's remaining results with the new operation's
-    // results.
-    rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults());
-    rewriter.eraseOp(op);
-    return success();
-  }
-
-private:
-  bool isValueFromInsideRegion(Value value,
-                               ExecuteRegionOp executeRegionOp) const {
-    // Check if the value is defined within the execute_region
-    if (Operation *defOp = value.getDefiningOp())
-      return &executeRegionOp.getRegion() == defOp->getParentRegion();
-
-    // If it's a block argument, check if it's from within the region
-    if (BlockArgument blockArg = dyn_cast<BlockArgument>(value))
-      return &executeRegionOp.getRegion() == blockArg.getParentRegion();
-
-    return false; // Value is from outside the region
-  }
-};
-
 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner,
-              ExecuteRegionForwardingEliminator>(context);
+  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
+  populateRegionBranchOpInterfaceCanonicalizationPatterns(
+      results, ExecuteRegionOp::getOperationName());
 }
 
 void ExecuteRegionOp::getSuccessorRegions(
@@ -989,146 +898,6 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase 
&rewriter, scf::ForOp forOp,
 }
 
 namespace {
-// Fold away ForOp iter arguments when:
-// 1) The op yields the iter arguments.
-// 2) The argument's corresponding outer region iterators (inputs) are yielded.
-// 3) The iter arguments have no use and the corresponding (operation) results
-// have no use.
-//
-// These arguments must be defined outside of the ForOp region and can just be
-// forwarded after simplifying the op inits, yields and returns.
-//
-// The implementation uses `inlineBlockBefore` to steal the content of the
-// original ForOp and avoid cloning.
-struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
-  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(scf::ForOp forOp,
-                                PatternRewriter &rewriter) const final {
-    bool canonicalize = false;
-
-    // An internal flat vector of block transfer
-    // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
-    // transformed block argument mappings. This plays the role of a
-    // IRMapping for the particular use case of calling into
-    // `inlineBlockBefore`.
-    int64_t numResults = forOp.getNumResults();
-    SmallVector<bool, 4> keepMask;
-    keepMask.reserve(numResults);
-    SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
-        newResultValues;
-    newBlockTransferArgs.reserve(1 + numResults);
-    newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
-    newIterArgs.reserve(forOp.getInitArgs().size());
-    newYieldValues.reserve(numResults);
-    newResultValues.reserve(numResults);
-    DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
-    for (auto [init, arg, result, yielded] :
-         llvm::zip(forOp.getInitArgs(),       // iter from outside
-                   forOp.getRegionIterArgs(), // iter inside region
-                   forOp.getResults(),        // op results
-                   forOp.getYieldedValues()   // iter yield
-                   )) {
-      // Forwarded is `true` when:
-      // 1) The region `iter` argument is yielded.
-      // 2) The region `iter` argument the corresponding input is yielded.
-      // 3) The region `iter` argument has no use, and the corresponding op
-      // result has no use.
-      bool forwarded = (arg == yielded) || (init == yielded) ||
-                       (arg.use_empty() && result.use_empty());
-      if (forwarded) {
-        canonicalize = true;
-        keepMask.push_back(false);
-        newBlockTransferArgs.push_back(init);
-        newResultValues.push_back(init);
-        continue;
-      }
-
-      // Check if a previous kept argument always has the same values for init
-      // and yielded values.
-      if (auto it = initYieldToArg.find({init, yielded});
-          it != initYieldToArg.end()) {
-        canonicalize = true;
-        keepMask.push_back(false);
-        auto [sameArg, sameResult] = it->second;
-        rewriter.replaceAllUsesWith(arg, sameArg);
-        rewriter.replaceAllUsesWith(result, sameResult);
-        // The replacement value doesn't matter because there are no uses.
-        newBlockTransferArgs.push_back(init);
-        newResultValues.push_back(init);
-        continue;
-      }
-
-      // This value is kept.
-      initYieldToArg.insert({{init, yielded}, {arg, result}});
-      keepMask.push_back(true);
-      newIterArgs.push_back(init);
-      newYieldValues.push_back(yielded);
-      newBlockTransferArgs.push_back(Value()); // placeholder with null value
-      newResultValues.push_back(Value());      // placeholder with null value
-    }
-
-    if (!canonicalize)
-      return failure();
-
-    scf::ForOp newForOp =
-        scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
-                           forOp.getUpperBound(), forOp.getStep(), newIterArgs,
-                           /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
-    newForOp->setAttrs(forOp->getAttrs());
-    Block &newBlock = newForOp.getRegion().front();
-
-    // Replace the null placeholders with newly constructed values.
-    newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
-    for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
-         idx != e; ++idx) {
-      Value &blockTransferArg = newBlockTransferArgs[1 + idx];
-      Value &newResultVal = newResultValues[idx];
-      assert((blockTransferArg && newResultVal) ||
-             (!blockTransferArg && !newResultVal));
-      if (!blockTransferArg) {
-        blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
-        newResultVal = newForOp.getResult(collapsedIdx++);
-      }
-    }
-
-    Block &oldBlock = forOp.getRegion().front();
-    assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
-           "unexpected argument size mismatch");
-
-    // No results case: the scf::ForOp builder already created a zero
-    // result terminator. Merge before this terminator and just get rid of the
-    // original terminator that has been merged in.
-    if (newIterArgs.empty()) {
-      auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
-      rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
-      rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
-      rewriter.replaceOp(forOp, newResultValues);
-      return success();
-    }
-
-    // No terminator case: merge and rewrite the merged terminator.
-    auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
-      OpBuilder::InsertionGuard g(rewriter);
-      rewriter.setInsertionPoint(mergedTerminator);
-      SmallVector<Value, 4> filteredOperands;
-      filteredOperands.reserve(newResultValues.size());
-      for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
-        if (keepMask[idx])
-          filteredOperands.push_back(mergedTerminator.getOperand(idx));
-      scf::YieldOp::create(rewriter, mergedTerminator.getLoc(),
-                           filteredOperands);
-    };
-
-    rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
-    auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
-    cloneFilteredTerminator(mergedYieldOp);
-    rewriter.eraseOp(mergedYieldOp);
-    rewriter.replaceOp(forOp, newResultValues);
-    return success();
-  }
-};
-
 /// Rewriting pattern that erases loops that are known not to iterate, replaces
 /// single-iteration loops with their bodies, and removes empty loops that
 /// iterate at least once and only return values defined outside of the loop.
@@ -1235,13 +1004,13 @@ struct ForOpTensorCastFolder : public 
OpRewritePattern<ForOp> {
     return failure();
   }
 };
-
 } // namespace
 
 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, 
ForOpTensorCastFolder>(
-      context);
+  results.add<SimplifyTrivialLoops, ForOpTensorCastFolder>(context);
+  populateRegionBranchOpInterfaceCanonicalizationPatterns(
+      results, ForOp::getOperationName());
 }
 
 std::optional<APInt> ForOp::getConstantStep() {
@@ -2378,35 +2147,6 @@ void IfOp::getRegionInvocationBounds(
 }
 
 namespace {
-// Pattern to remove unused IfOp results.
-struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
-  using OpRewritePattern<IfOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(IfOp op,
-                                PatternRewriter &rewriter) const override {
-    // Compute the list of unused results.
-    BitVector toErase(op.getNumResults(), false);
-    for (auto [idx, result] : llvm::enumerate(op.getResults()))
-      if (result.use_empty())
-        toErase[idx] = true;
-    if (toErase.none())
-      return rewriter.notifyMatchFailure(op, "no results to erase");
-
-    // Erase results.
-    auto newOp = cast<scf::IfOp>(rewriter.eraseOpResults(op, toErase));
-
-    // Erase operands.
-    rewriter.modifyOpInPlace(newOp.thenYield(), [&]() {
-      newOp.thenYield()->eraseOperands(toErase);
-    });
-    rewriter.modifyOpInPlace(newOp.elseYield(), [&]() {
-      newOp.elseYield()->eraseOperands(toErase);
-    });
-
-    return success();
-  }
-};
-
 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
   using OpRewritePattern<IfOp>::OpRewritePattern;
 
@@ -2977,8 +2717,10 @@ void IfOp::getCanonicalizationPatterns(RewritePatternSet 
&results,
                                        MLIRContext *context) {
   results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
               ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
-              RemoveStaticCondition, RemoveUnusedResults,
-              ReplaceIfYieldWithConditionOrValue>(context);
+              RemoveStaticCondition, ReplaceIfYieldWithConditionOrValue>(
+      context);
+  populateRegionBranchOpInterfaceCanonicalizationPatterns(
+      results, IfOp::getOperationName());
 }
 
 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
@@ -3816,390 +3558,6 @@ struct WhileConditionTruth : public 
OpRewritePattern<WhileOp> {
   }
 };
 
-/// Remove loop invariant arguments from `before` block of scf.while.
-/// A before block argument is considered loop invariant if :-
-///   1. i-th yield operand is equal to the i-th while operand.
-///   2. i-th yield operand is k-th after block argument which is (k+1)-th
-///      condition operand AND this (k+1)-th condition operand is equal to i-th
-///      iter argument/while operand.
-/// For the arguments which are removed, their uses inside scf.while
-/// are replaced with their corresponding initial value.
-///
-/// Eg:
-///    INPUT :-
-///    %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
-///                                     ..., %argN_before = %N)
-///           {
-///                ...
-///                scf.condition(%cond) %arg1_before, %arg0_before,
-///                                     %arg2_before, %arg0_before, ...
-///           } do {
-///             ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
-///                  ..., %argK_after):
-///                ...
-///                scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
-///           }
-///
-///    OUTPUT :-
-///    %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
-///                                     %N)
-///           {
-///                ...
-///                scf.condition(%cond) %b, %a, %arg2_before, %a, ...
-///           } do {
-///             ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
-///                  ..., %argK_after):
-///                ...
-///                scf.yield %arg1_after, ..., %argN
-///           }
-///
-///    EXPLANATION:
-///      We iterate over each yield operand.
-///        1. 0-th yield operand %arg0_after_2 is 4-th condition operand
-///           %arg0_before, which in turn is the 0-th iter argument. So we
-///           remove 0-th before block argument and yield operand, and replace
-///           all uses of the 0-th before block argument with its initial value
-///           %a.
-///        2. 1-th yield operand %b is equal to the 1-th iter arg's initial
-///           value. So we remove this operand and the corresponding before
-///           block argument and replace all uses of 1-th before block argument
-///           with %b.
-struct RemoveLoopInvariantArgsFromBeforeBlock
-    : public OpRewritePattern<WhileOp> {
-  using OpRewritePattern<WhileOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(WhileOp op,
-                                PatternRewriter &rewriter) const override {
-    Block &afterBlock = *op.getAfterBody();
-    Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
-    ConditionOp condOp = op.getConditionOp();
-    OperandRange condOpArgs = condOp.getArgs();
-    Operation *yieldOp = afterBlock.getTerminator();
-    ValueRange yieldOpArgs = yieldOp->getOperands();
-
-    bool canSimplify = false;
-    for (const auto &it :
-         llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
-      auto index = static_cast<unsigned>(it.index());
-      auto [initVal, yieldOpArg] = it.value();
-      // If i-th yield operand is equal to the i-th operand of the scf.while,
-      // the i-th before block argument is a loop invariant.
-      if (yieldOpArg == initVal) {
-        canSimplify = true;
-        break;
-      }
-      // If the i-th yield operand is k-th after block argument, then we check
-      // if the (k+1)-th condition op operand is equal to either the i-th 
before
-      // block argument or the initial value of i-th before block argument. If
-      // the comparison results `true`, i-th before block argument is a loop
-      // invariant.
-      auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
-      if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
-        Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
-        if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
-          canSimplify = true;
-          break;
-        }
- ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/174094
_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to