https://github.com/matthias-springer updated 
https://github.com/llvm/llvm-project/pull/174094

>From 8b2bcb7b6652272bc614d37ca2305000797b2b3f Mon Sep 17 00:00:00 2001
From: Matthias Springer <[email protected]>
Date: Wed, 31 Dec 2025 14:07:51 +0000
Subject: [PATCH] [mlir][draft] Consolidate patterns into
 RegionBranchOpInterface patterns

fix some tests
---
 .../SparseTensor/IR/SparseTensorOps.td        |    7 +-
 .../mlir/Interfaces/ControlFlowInterfaces.h   |    2 +
 .../mlir/Interfaces/ControlFlowInterfaces.td  |    9 +
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 1134 ++++-------------
 mlir/lib/Interfaces/ControlFlowInterfaces.cpp |   39 +
 mlir/test/Dialect/SCF/canonicalize.mlir       |   24 +-
 .../Dialect/SparseTensor/sparse_kernels.mlir  |   16 +-
 .../test/Dialect/SparseTensor/sparse_out.mlir |   34 +-
 .../Vector/vector-warp-distribute.mlir        |    6 +-
 mlir/test/Transforms/remove-dead-values.mlir  |    8 +-
 10 files changed, 377 insertions(+), 902 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td 
b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index a61d90a0c39b1..f41b3694d9c79 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1304,9 +1304,10 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", 
[Pure, SameOperandsAndResu
   let hasVerifier = 1;
 }
 
-def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
-    ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
-                 "ForeachOp", "IterateOp", "CoIterateOp"]>]> {
+def SparseTensor_YieldOp : SparseTensor_Op<"yield",
+    [Pure, Terminator, ReturnLike,
+     ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
+                  "ForeachOp", "IterateOp", "CoIterateOp"]>]> {
   let summary = "Yield from sparse_tensor set-like operations";
   let description = [{
       Yields a value from within a `binary`, `unary`, `reduce`,
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h 
b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 566f4b8fadb5d..a7565f9f7bb78 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -188,6 +188,8 @@ LogicalResult verifyTypesAlongControlFlowEdges(Operation 
*op);
 /// possible successors.) Operands that not forwarded at all are not present in
 /// the mapping.
 using RegionBranchSuccessorMapping = DenseMap<OpOperand *, SmallVector<Value>>;
+using RegionBranchInverseSuccessorMapping =
+    DenseMap<Value, SmallVector<OpOperand *>>;
 
 /// This class represents a successor of a region. A region successor can 
either
 /// be another region, or the parent operation. If the successor is a region,
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td 
b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 2e654ba04ffe5..9366e5562b774 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -355,6 +355,15 @@ def RegionBranchOpInterface : 
OpInterface<"RegionBranchOpInterface"> {
         ::mlir::RegionBranchSuccessorMapping &mapping,
         std::optional<::mlir::RegionBranchPoint> src = std::nullopt);
 
+    /// Build a mapping from successor inputs to successor operands. This is
+    /// the same as "getSuccessorOperandInputMapping", but inverted.
+    void getSuccessorInputOperandMapping(
+        ::mlir::RegionBranchInverseSuccessorMapping &mapping);
+
+    /// Compute all values that a successor input could possibly have. If the
+    /// given value is not a successor input, return an empty set.
+    ::llvm::DenseSet<Value> 
computePossibleValuesOfSuccessorInput(::mlir::Value value);
+
     /// Return all possible region branch points: the region branch op itself
     /// and all region branch terminators.
     ::llvm::SmallVector<::mlir::RegionBranchPoint> getAllRegionBranchPoints();
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 0a123112cf68f..06b542d1c1dae 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/EquivalenceClasses.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallPtrSet.h"
@@ -291,102 +292,9 @@ struct MultiBlockExecuteInliner : public 
OpRewritePattern<ExecuteRegionOp> {
   }
 };
 
-// Pattern to eliminate ExecuteRegionOp results which forward external
-// values from the region. In case there are multiple yield operations,
-// all of them must have the same operands in order for the pattern to be
-// applicable.
-struct ExecuteRegionForwardingEliminator
-    : public OpRewritePattern<ExecuteRegionOp> {
-  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExecuteRegionOp op,
-                                PatternRewriter &rewriter) const override {
-    if (op.getNumResults() == 0)
-      return failure();
-
-    SmallVector<Operation *> yieldOps;
-    for (Block &block : op.getRegion()) {
-      if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator()))
-        yieldOps.push_back(yield.getOperation());
-    }
-
-    if (yieldOps.empty())
-      return failure();
-
-    // Check if all yield operations have the same operands.
-    auto yieldOpsOperands = yieldOps[0]->getOperands();
-    for (auto *yieldOp : yieldOps) {
-      if (yieldOp->getOperands() != yieldOpsOperands)
-        return failure();
-    }
-
-    SmallVector<Value> externalValues;
-    SmallVector<Value> internalValues;
-    SmallVector<Value> opResultsToReplaceWithExternalValues;
-    SmallVector<Value> opResultsToKeep;
-    for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) {
-      if (isValueFromInsideRegion(yieldedValue, op)) {
-        internalValues.push_back(yieldedValue);
-        opResultsToKeep.push_back(op.getResult(index));
-      } else {
-        externalValues.push_back(yieldedValue);
-        opResultsToReplaceWithExternalValues.push_back(op.getResult(index));
-      }
-    }
-    // No yielded external values - nothing to do.
-    if (externalValues.empty())
-      return failure();
-
-    // There are yielded external values - create a new execute_region 
returning
-    // just the internal values.
-    SmallVector<Type> resultTypes;
-    for (Value value : internalValues)
-      resultTypes.push_back(value.getType());
-    auto newOp =
-        ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes));
-    newOp->setAttrs(op->getAttrs());
-
-    // Move old op's region to the new operation.
-    rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
-                                newOp.getRegion().end());
-
-    // Replace all yield operations with a new yield operation with updated
-    // results. scf.execute_region must have at least one yield operation.
-    for (auto *yieldOp : yieldOps) {
-      rewriter.setInsertionPoint(yieldOp);
-      rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp,
-                                                ValueRange(internalValues));
-    }
-
-    // Replace the old operation with the external values directly.
-    rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues,
-                                externalValues);
-    // Replace the old operation's remaining results with the new operation's
-    // results.
-    rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults());
-    rewriter.eraseOp(op);
-    return success();
-  }
-
-private:
-  bool isValueFromInsideRegion(Value value,
-                               ExecuteRegionOp executeRegionOp) const {
-    // Check if the value is defined within the execute_region
-    if (Operation *defOp = value.getDefiningOp())
-      return &executeRegionOp.getRegion() == defOp->getParentRegion();
-
-    // If it's a block argument, check if it's from within the region
-    if (BlockArgument blockArg = dyn_cast<BlockArgument>(value))
-      return &executeRegionOp.getRegion() == blockArg.getParentRegion();
-
-    return false; // Value is from outside the region
-  }
-};
-
 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner,
-              ExecuteRegionForwardingEliminator>(context);
+  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
 }
 
 void ExecuteRegionOp::getSuccessorRegions(
@@ -989,146 +897,6 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase 
&rewriter, scf::ForOp forOp,
 }
 
 namespace {
-// Fold away ForOp iter arguments when:
-// 1) The op yields the iter arguments.
-// 2) The argument's corresponding outer region iterators (inputs) are yielded.
-// 3) The iter arguments have no use and the corresponding (operation) results
-// have no use.
-//
-// These arguments must be defined outside of the ForOp region and can just be
-// forwarded after simplifying the op inits, yields and returns.
-//
-// The implementation uses `inlineBlockBefore` to steal the content of the
-// original ForOp and avoid cloning.
-struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
-  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(scf::ForOp forOp,
-                                PatternRewriter &rewriter) const final {
-    bool canonicalize = false;
-
-    // An internal flat vector of block transfer
-    // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
-    // transformed block argument mappings. This plays the role of a
-    // IRMapping for the particular use case of calling into
-    // `inlineBlockBefore`.
-    int64_t numResults = forOp.getNumResults();
-    SmallVector<bool, 4> keepMask;
-    keepMask.reserve(numResults);
-    SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
-        newResultValues;
-    newBlockTransferArgs.reserve(1 + numResults);
-    newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
-    newIterArgs.reserve(forOp.getInitArgs().size());
-    newYieldValues.reserve(numResults);
-    newResultValues.reserve(numResults);
-    DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
-    for (auto [init, arg, result, yielded] :
-         llvm::zip(forOp.getInitArgs(),       // iter from outside
-                   forOp.getRegionIterArgs(), // iter inside region
-                   forOp.getResults(),        // op results
-                   forOp.getYieldedValues()   // iter yield
-                   )) {
-      // Forwarded is `true` when:
-      // 1) The region `iter` argument is yielded.
-      // 2) The region `iter` argument the corresponding input is yielded.
-      // 3) The region `iter` argument has no use, and the corresponding op
-      // result has no use.
-      bool forwarded = (arg == yielded) || (init == yielded) ||
-                       (arg.use_empty() && result.use_empty());
-      if (forwarded) {
-        canonicalize = true;
-        keepMask.push_back(false);
-        newBlockTransferArgs.push_back(init);
-        newResultValues.push_back(init);
-        continue;
-      }
-
-      // Check if a previous kept argument always has the same values for init
-      // and yielded values.
-      if (auto it = initYieldToArg.find({init, yielded});
-          it != initYieldToArg.end()) {
-        canonicalize = true;
-        keepMask.push_back(false);
-        auto [sameArg, sameResult] = it->second;
-        rewriter.replaceAllUsesWith(arg, sameArg);
-        rewriter.replaceAllUsesWith(result, sameResult);
-        // The replacement value doesn't matter because there are no uses.
-        newBlockTransferArgs.push_back(init);
-        newResultValues.push_back(init);
-        continue;
-      }
-
-      // This value is kept.
-      initYieldToArg.insert({{init, yielded}, {arg, result}});
-      keepMask.push_back(true);
-      newIterArgs.push_back(init);
-      newYieldValues.push_back(yielded);
-      newBlockTransferArgs.push_back(Value()); // placeholder with null value
-      newResultValues.push_back(Value());      // placeholder with null value
-    }
-
-    if (!canonicalize)
-      return failure();
-
-    scf::ForOp newForOp =
-        scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
-                           forOp.getUpperBound(), forOp.getStep(), newIterArgs,
-                           /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
-    newForOp->setAttrs(forOp->getAttrs());
-    Block &newBlock = newForOp.getRegion().front();
-
-    // Replace the null placeholders with newly constructed values.
-    newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
-    for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
-         idx != e; ++idx) {
-      Value &blockTransferArg = newBlockTransferArgs[1 + idx];
-      Value &newResultVal = newResultValues[idx];
-      assert((blockTransferArg && newResultVal) ||
-             (!blockTransferArg && !newResultVal));
-      if (!blockTransferArg) {
-        blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
-        newResultVal = newForOp.getResult(collapsedIdx++);
-      }
-    }
-
-    Block &oldBlock = forOp.getRegion().front();
-    assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
-           "unexpected argument size mismatch");
-
-    // No results case: the scf::ForOp builder already created a zero
-    // result terminator. Merge before this terminator and just get rid of the
-    // original terminator that has been merged in.
-    if (newIterArgs.empty()) {
-      auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
-      rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
-      rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
-      rewriter.replaceOp(forOp, newResultValues);
-      return success();
-    }
-
-    // No terminator case: merge and rewrite the merged terminator.
-    auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
-      OpBuilder::InsertionGuard g(rewriter);
-      rewriter.setInsertionPoint(mergedTerminator);
-      SmallVector<Value, 4> filteredOperands;
-      filteredOperands.reserve(newResultValues.size());
-      for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
-        if (keepMask[idx])
-          filteredOperands.push_back(mergedTerminator.getOperand(idx));
-      scf::YieldOp::create(rewriter, mergedTerminator.getLoc(),
-                           filteredOperands);
-    };
-
-    rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
-    auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
-    cloneFilteredTerminator(mergedYieldOp);
-    rewriter.eraseOp(mergedYieldOp);
-    rewriter.replaceOp(forOp, newResultValues);
-    return success();
-  }
-};
-
 /// Rewriting pattern that erases loops that are known not to iterate, replaces
 /// single-iteration loops with their bodies, and removes empty loops that
 /// iterate at least once and only return values defined outside of the loop.
@@ -1236,12 +1004,283 @@ struct ForOpTensorCastFolder : public 
OpRewritePattern<ForOp> {
   }
 };
 
+/// Is `a` defined before `b`? True if the parent op of `a`'s region is a
+/// proper ancestor of the parent op of `b`'s region, or if both values are in
+/// the same block and `a` comes first. Values in different blocks of the same
+/// region are never treated as ordered.
+static bool isDefinedBefore(Value a, Value b) {
+  Region *regionOfA = a.getParentRegion();
+  Region *regionOfB = b.getParentRegion();
+  // NOTE(review): only nesting is compared here; dominance is not verified.
+  if (regionOfA->getParentOp()->isProperAncestor(regionOfB->getParentOp()))
+    return true;
+
+  // Beyond this point, values in different regions or blocks are unordered.
+  if (regionOfA != regionOfB || a.getParentBlock() != b.getParentBlock())
+    return false;
+
+  // Same block: block arguments come first, ops follow in block order.
+  if (isa<BlockArgument>(a))
+    return true;
+  if (isa<BlockArgument>(b))
+    return false;
+  return a.getDefiningOp()->isBeforeInBlock(b.getDefiningOp());
+}
+
+// Try to make successor inputs dead: if a successor input can only ever have
+// a single possible value, its uses are replaced with that value. This pattern
+// enables additional canonicalization opportunities for RemoveDeadValues.
+struct RemoveUsesOfIdenticalValues
+    : public OpInterfaceRewritePattern<RegionBranchOpInterface> {
+  using OpInterfaceRewritePattern<
+      RegionBranchOpInterface>::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(RegionBranchOpInterface op,
+                                PatternRewriter &rewriter) const override {
+    // TODO: ForallOp data flow is modeled incompletely.
+    if (isa<ForallOp>(op))
+      return failure();
+
+    // Candidates are all op results and all region arguments. Candidates that
+    // are not successor inputs have no possible values and are skipped below.
+    SmallVector<Value> candidates;
+    llvm::append_range(candidates, op->getResults());
+    for (Region &region : op->getRegions())
+      llvm::append_range(candidates, region.getArguments());
+    bool changed = false;
+    for (Value candidate : candidates) {
+      if (candidate.use_empty())
+        continue;
+      DenseSet<Value> possibleValues =
+          op.computePossibleValuesOfSuccessorInput(candidate);
+      if (possibleValues.size() != 1)
+        continue;
+      Value replacement = *possibleValues.begin();
+      if (replacement == candidate || !isDefinedBefore(replacement, candidate))
+        continue;
+      rewriter.replaceAllUsesWith(candidate, replacement);
+      changed = true;
+    }
+    return success(changed);
+  }
+};
+
+/// Pattern to remove dead successor inputs (op results and region block
+/// arguments) of region branch ops, along with the operands forwarded to them.
+struct RemoveDeadValues
+    : public OpInterfaceRewritePattern<RegionBranchOpInterface> {
+  using OpInterfaceRewritePattern<
+      RegionBranchOpInterface>::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(RegionBranchOpInterface op,
+                                PatternRewriter &rewriter) const override {
+    // TODO: ForallOp data flow is modeled incompletely.
+    if (isa<ForallOp>(op))
+      return failure();
+
+    // Compute tied values: values that must come as a set. If you remove one,
+    // you must remove all. Inputs fed by the same operand are tied.
+    RegionBranchSuccessorMapping operandToInputs;
+    op.getSuccessorOperandInputMapping(operandToInputs);
+    llvm::EquivalenceClasses<Value> tiedSuccessorInputs;
+    for (const auto &[operand, inputs] : operandToInputs) {
+      assert(!inputs.empty() && "expected non-empty inputs");
+      Value firstInput = inputs.front();
+      tiedSuccessorInputs.insert(firstInput);
+      for (Value nextInput : llvm::drop_begin(inputs))
+        tiedSuccessorInputs.unionSets(firstInput, nextInput);
+    }
+
+    // Determine which values to remove and group them by block and operation.
+    SmallVector<Value> valuesToRemove;
+    DenseMap<Block *, BitVector> blockArgsToRemove;
+    DenseMap<Operation *, BitVector> resultsToRemove;
+    for (auto it = tiedSuccessorInputs.begin(), e = tiedSuccessorInputs.end();
+         it != e; ++it) {
+      if (!(*it)->isLeader())
+        continue;
+
+      // Value can be removed if it is dead and all other tied values are also
+      // dead.
+      bool allDead = true;
+      for (auto memberIt = tiedSuccessorInputs.member_begin(**it);
+           memberIt != tiedSuccessorInputs.member_end(); ++memberIt) {
+        if (!memberIt->use_empty()) {
+          allDead = false;
+          break;
+        }
+      }
+      if (!allDead)
+        continue;
+
+      // Group values by block and operation.
+      for (auto memberIt = tiedSuccessorInputs.member_begin(**it);
+           memberIt != tiedSuccessorInputs.member_end(); ++memberIt) {
+        if (auto arg = dyn_cast<BlockArgument>(*memberIt)) {
+          BitVector &vector =
+              blockArgsToRemove
+                  .try_emplace(arg.getOwner(),
+                               arg.getOwner()->getNumArguments(), false)
+                  .first->second;
+          vector.set(arg.getArgNumber());
+        } else {
+          OpResult result = cast<OpResult>(*memberIt);
+          BitVector &vector =
+              resultsToRemove
+                  .try_emplace(result.getDefiningOp(),
+                               result.getDefiningOp()->getNumResults(), false)
+                  .first->second;
+          vector.set(result.getResultNumber());
+        }
+        valuesToRemove.push_back(*memberIt);
+      }
+    }
+
+    if (valuesToRemove.empty())
+      return rewriter.notifyMatchFailure(op, "no values to remove");
+
+    // Find operands that must be removed together with the values.
+    RegionBranchInverseSuccessorMapping inputsToOperands;
+    op.getSuccessorInputOperandMapping(inputsToOperands);
+    DenseMap<Operation *, llvm::BitVector> operandsToRemove;
+    for (Value value : valuesToRemove) {
+      for (OpOperand *operand : inputsToOperands[value]) {
+        BitVector &vector =
+            operandsToRemove
+                .try_emplace(operand->getOwner(),
+                             operand->getOwner()->getNumOperands(), false)
+                .first->second;
+        vector.set(operand->getOperandNumber());
+      }
+    }
+
+    // Erase operands. (C++17 lambdas cannot capture structured bindings.)
+    for (auto &entry : operandsToRemove)
+      rewriter.modifyOpInPlace(
+          entry.first, [&]() { entry.first->eraseOperands(entry.second); });
+
+    // Erase block arguments.
+    for (auto &entry : blockArgsToRemove)
+      rewriter.modifyOpInPlace(entry.first->getParentOp(), [&]() {
+        entry.first->eraseArguments(entry.second);
+      });
+
+    // Erase op results.
+    // TODO: Can we move this to RewriterBase, so we have a uniform API,
+    // similar to eraseArguments?
+    for (auto &entry : resultsToRemove) {
+      Operation *oldOp = entry.first;
+      const BitVector &resultsToErase = entry.second;
+      rewriter.setInsertionPoint(oldOp);
+      SmallVector<Type> newResultTypes;
+      for (OpResult result : oldOp->getResults())
+        if (!resultsToErase[result.getResultNumber()])
+          newResultTypes.push_back(result.getType());
+      OperationState state(oldOp->getLoc(), oldOp->getName().getStringRef(),
+                           oldOp->getOperands(), newResultTypes,
+                           oldOp->getAttrs());
+      for (unsigned i = 0, e = oldOp->getNumRegions(); i < e; ++i)
+        state.addRegion();
+      Operation *newOp = rewriter.create(state);
+      // Move all blocks of the old regions into the new op's regions.
+      for (auto [oldRegion, newRegion] :
+           llvm::zip(oldOp->getRegions(), newOp->getRegions()))
+        rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
+
+      // Erased results are dead and are replaced with null values.
+      SmallVector<Value> newResults(oldOp->getNumResults());
+      unsigned nextLiveResult = 0;
+      for (unsigned i = 0, e = oldOp->getNumResults(); i < e; ++i)
+        if (!resultsToErase[i])
+          newResults[i] = newOp->getResult(nextLiveResult++);
+      rewriter.replaceOp(oldOp, newResults);
+    }
+
+    return success();
+  }
+};
+
+/// Return the defining op (for op results) or owning block (for block args).
+static void *getContainerOwnerOfValue(Value value) {
+  if (auto opResult = llvm::dyn_cast<OpResult>(value))
+    return opResult.getDefiningOp();
+  return llvm::cast<BlockArgument>(value).getOwner();
+}
+
+/// Return the result number (op results) or argument number (block args).
+static unsigned getArgOrResultNumber(Value value) {
+  if (auto opResult = llvm::dyn_cast<OpResult>(value))
+    return opResult.getResultNumber();
+  return llvm::cast<BlockArgument>(value).getArgNumber();
+}
+
+/// Pattern to make duplicate successor inputs dead. Two successor inputs are
+/// duplicate if their corresponding successor operands have the same values.
+struct RemoveDuplicateSuccessorInputUses
+    : public OpInterfaceRewritePattern<RegionBranchOpInterface> {
+  using OpInterfaceRewritePattern<
+      RegionBranchOpInterface>::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(RegionBranchOpInterface op,
+                                PatternRewriter &rewriter) const override {
+    // TODO: ForallOp data flow is modeled incompletely.
+    if (isa<ForallOp>(op))
+      return failure();
+
+    // Collect all successor inputs and sort them by argument / result number.
+    // Sorting is not required for correctness: it helps sets of tied successor
+    // inputs to become entirely dead, which RemoveDeadValues requires in order
+    // to erase them.
+    RegionBranchInverseSuccessorMapping inputsToOperands;
+    op.getSuccessorInputOperandMapping(inputsToOperands);
+    SmallVector<Value> inputs = llvm::to_vector(inputsToOperands.keys());
+    llvm::sort(inputs, [](Value a, Value b) {
+      return getArgOrResultNumber(a) < getArgOrResultNumber(b);
+    });
+
+    // Map each predecessor op to the value that it forwards to `input`.
+    auto getForwardedValues = [&](Value input) {
+      DenseMap<Operation *, Value> forwarded;
+      for (OpOperand *operand : inputsToOperands[input]) {
+        assert(!forwarded.contains(operand->getOwner()));
+        forwarded[operand->getOwner()] = operand->get();
+      }
+      return forwarded;
+    };
+
+    bool changed = false;
+    for (unsigned i = 0, e = inputs.size(); i < e; ++i) {
+      Value input1 = inputs[i];
+      for (unsigned j = i + 1; j < e; ++j) {
+        Value input2 = inputs[j];
+        // Nothing to do if input2 is already dead.
+        if (input2.use_empty())
+          continue;
+        // Replace only values of the same kind.
+        if (isa<BlockArgument>(input1) != isa<BlockArgument>(input2))
+          continue;
+        // Replace only values that belong to the same block / operation.
+        if (getContainerOwnerOfValue(input1) !=
+            getContainerOwnerOfValue(input2))
+          continue;
+        // The inputs are duplicates if every predecessor forwards the same
+        // value to both. Recomputed per pair: RAUW below may change operands.
+        if (getForwardedValues(input1) == getForwardedValues(input2)) {
+          rewriter.replaceAllUsesWith(input2, input1);
+          changed = true;
+        }
+      }
+    }
+    return success(changed);
+  }
+};
 } // namespace
 
 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, 
ForOpTensorCastFolder>(
-      context);
+  results.add<SimplifyTrivialLoops, ForOpTensorCastFolder,
+              RemoveUsesOfIdenticalValues, RemoveDeadValues,
+              RemoveDuplicateSuccessorInputUses>(context);
 }
 
 std::optional<APInt> ForOp::getConstantStep() {
@@ -2409,61 +2448,6 @@ void IfOp::getRegionInvocationBounds(
 }
 
 namespace {
-// Pattern to remove unused IfOp results.
-struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
-  using OpRewritePattern<IfOp>::OpRewritePattern;
-
-  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
-                    PatternRewriter &rewriter) const {
-    // Move all operations to the destination block.
-    rewriter.mergeBlocks(source, dest);
-    // Replace the yield op by one that returns only the used values.
-    auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
-    SmallVector<Value, 4> usedOperands;
-    llvm::transform(usedResults, std::back_inserter(usedOperands),
-                    [&](OpResult result) {
-                      return yieldOp.getOperand(result.getResultNumber());
-                    });
-    rewriter.modifyOpInPlace(yieldOp,
-                             [&]() { yieldOp->setOperands(usedOperands); });
-  }
-
-  LogicalResult matchAndRewrite(IfOp op,
-                                PatternRewriter &rewriter) const override {
-    // Compute the list of used results.
-    SmallVector<OpResult, 4> usedResults;
-    llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
-                  [](OpResult result) { return !result.use_empty(); });
-
-    // Replace the operation if only a subset of its results have uses.
-    if (usedResults.size() == op.getNumResults())
-      return failure();
-
-    // Compute the result types of the replacement operation.
-    SmallVector<Type, 4> newTypes;
-    llvm::transform(usedResults, std::back_inserter(newTypes),
-                    [](OpResult result) { return result.getType(); });
-
-    // Create a replacement operation with empty then and else regions.
-    auto newOp =
-        IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition());
-    rewriter.createBlock(&newOp.getThenRegion());
-    rewriter.createBlock(&newOp.getElseRegion());
-
-    // Move the bodies and replace the terminators (note there is a then and
-    // an else region since the operation returns results).
-    transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
-    transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
-
-    // Replace the operation by the new one.
-    SmallVector<Value, 4> repResults(op.getNumResults());
-    for (const auto &en : llvm::enumerate(usedResults))
-      repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
-    rewriter.replaceOp(op, repResults);
-    return success();
-  }
-};
-
 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
   using OpRewritePattern<IfOp>::OpRewritePattern;
 
@@ -3034,8 +3018,8 @@ void IfOp::getCanonicalizationPatterns(RewritePatternSet 
&results,
                                        MLIRContext *context) {
   results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
               ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
-              RemoveStaticCondition, RemoveUnusedResults,
-              ReplaceIfYieldWithConditionOrValue>(context);
+              RemoveStaticCondition, ReplaceIfYieldWithConditionOrValue>(
+      context);
 }
 
 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
@@ -3873,390 +3857,6 @@ struct WhileConditionTruth : public 
OpRewritePattern<WhileOp> {
   }
 };
 
-/// Remove loop invariant arguments from `before` block of scf.while.
-/// A before block argument is considered loop invariant if :-
-///   1. i-th yield operand is equal to the i-th while operand.
-///   2. i-th yield operand is k-th after block argument which is (k+1)-th
-///      condition operand AND this (k+1)-th condition operand is equal to i-th
-///      iter argument/while operand.
-/// For the arguments which are removed, their uses inside scf.while
-/// are replaced with their corresponding initial value.
-///
-/// Eg:
-///    INPUT :-
-///    %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
-///                                     ..., %argN_before = %N)
-///           {
-///                ...
-///                scf.condition(%cond) %arg1_before, %arg0_before,
-///                                     %arg2_before, %arg0_before, ...
-///           } do {
-///             ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
-///                  ..., %argK_after):
-///                ...
-///                scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
-///           }
-///
-///    OUTPUT :-
-///    %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
-///                                     %N)
-///           {
-///                ...
-///                scf.condition(%cond) %b, %a, %arg2_before, %a, ...
-///           } do {
-///             ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
-///                  ..., %argK_after):
-///                ...
-///                scf.yield %arg1_after, ..., %argN
-///           }
-///
-///    EXPLANATION:
-///      We iterate over each yield operand.
-///        1. 0-th yield operand %arg0_after_2 is 4-th condition operand
-///           %arg0_before, which in turn is the 0-th iter argument. So we
-///           remove 0-th before block argument and yield operand, and replace
-///           all uses of the 0-th before block argument with its initial value
-///           %a.
-///        2. 1-th yield operand %b is equal to the 1-th iter arg's initial
-///           value. So we remove this operand and the corresponding before
-///           block argument and replace all uses of 1-th before block argument
-///           with %b.
-struct RemoveLoopInvariantArgsFromBeforeBlock
-    : public OpRewritePattern<WhileOp> {
-  using OpRewritePattern<WhileOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(WhileOp op,
-                                PatternRewriter &rewriter) const override {
-    Block &afterBlock = *op.getAfterBody();
-    Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
-    ConditionOp condOp = op.getConditionOp();
-    OperandRange condOpArgs = condOp.getArgs();
-    Operation *yieldOp = afterBlock.getTerminator();
-    ValueRange yieldOpArgs = yieldOp->getOperands();
-
-    bool canSimplify = false;
-    for (const auto &it :
-         llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
-      auto index = static_cast<unsigned>(it.index());
-      auto [initVal, yieldOpArg] = it.value();
-      // If i-th yield operand is equal to the i-th operand of the scf.while,
-      // the i-th before block argument is a loop invariant.
-      if (yieldOpArg == initVal) {
-        canSimplify = true;
-        break;
-      }
-      // If the i-th yield operand is k-th after block argument, then we check
-      // if the (k+1)-th condition op operand is equal to either the i-th 
before
-      // block argument or the initial value of i-th before block argument. If
-      // the comparison results `true`, i-th before block argument is a loop
-      // invariant.
-      auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
-      if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
-        Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
-        if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
-          canSimplify = true;
-          break;
-        }
-      }
-    }
-
-    if (!canSimplify)
-      return failure();
-
-    SmallVector<Value> newInitArgs, newYieldOpArgs;
-    DenseMap<unsigned, Value> beforeBlockInitValMap;
-    SmallVector<Location> newBeforeBlockArgLocs;
-    for (const auto &it :
-         llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
-      auto index = static_cast<unsigned>(it.index());
-      auto [initVal, yieldOpArg] = it.value();
-
-      // If i-th yield operand is equal to the i-th operand of the scf.while,
-      // the i-th before block argument is a loop invariant.
-      if (yieldOpArg == initVal) {
-        beforeBlockInitValMap.insert({index, initVal});
-        continue;
-      } else {
-        // If the i-th yield operand is k-th after block argument, then we 
check
-        // if the (k+1)-th condition op operand is equal to either the i-th
-        // before block argument or the initial value of i-th before block
-        // argument. If the comparison results `true`, i-th before block
-        // argument is a loop invariant.
-        auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
-        if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
-          Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
-          if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
-            beforeBlockInitValMap.insert({index, initVal});
-            continue;
-          }
-        }
-      }
-      newInitArgs.emplace_back(initVal);
-      newYieldOpArgs.emplace_back(yieldOpArg);
-      newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
-    }
-
-    {
-      OpBuilder::InsertionGuard g(rewriter);
-      rewriter.setInsertionPoint(yieldOp);
-      rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
-    }
-
-    auto newWhile = WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(),
-                                    newInitArgs);
-
-    Block &newBeforeBlock = *rewriter.createBlock(
-        &newWhile.getBefore(), /*insertPt*/ {},
-        ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
-
-    Block &beforeBlock = *op.getBeforeBody();
-    SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
-    // For each i-th before block argument we find it's replacement value as :-
-    //   1. If i-th before block argument is a loop invariant, we fetch it's
-    //      initial value from `beforeBlockInitValMap` by querying for key `i`.
-    //   2. Else we fetch j-th new before block argument as the replacement
-    //      value of i-th before block argument.
-    for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) 
{
-      // If the index 'i' argument was a loop invariant we fetch it's initial
-      // value from `beforeBlockInitValMap`.
-      if (beforeBlockInitValMap.count(i) != 0)
-        newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
-      else
-        newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
-    }
-
-    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
-    rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
-                                newWhile.getAfter().begin());
-
-    rewriter.replaceOp(op, newWhile.getResults());
-    return success();
-  }
-};
-
-/// Remove loop invariant value from result (condition op) of scf.while.
-/// A value is considered loop invariant if the final value yielded by
-/// scf.condition is defined outside of the `before` block. We remove the
-/// corresponding argument in `after` block and replace the use with the value.
-/// We also replace the use of the corresponding result of scf.while with the
-/// value.
-///
-/// Eg:
-///    INPUT :-
-///    %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
-///                                             %argN_before = %N) {
-///                ...
-///                scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
-///           } do {
-///             ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
-///                ...
-///                some_func(%arg1_after)
-///                ...
-///                scf.yield %arg0_after, %arg2_after, ..., %argN_after
-///           }
-///
-///    OUTPUT :-
-///    %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
-///                ...
-///                scf.condition(%cond) %arg0, %arg1, ..., %argM
-///           } do {
-///             ^bb0(%arg0, %arg3, ..., %argM):
-///                ...
-///                some_func(%a)
-///                ...
-///                scf.yield %arg0, %b, ..., %argN
-///           }
-///
-///     EXPLANATION:
-///       1. The 1-th and 2-th operand of scf.condition are defined outside the
-///          before block of scf.while, so they get removed.
-///       2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
-///          replaced by %b.
-///       3. The corresponding after block argument %arg1_after's uses are
-///          replaced by %a and %arg2_after's uses are replaced by %b.
-struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
-  using OpRewritePattern<WhileOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(WhileOp op,
-                                PatternRewriter &rewriter) const override {
-    Block &beforeBlock = *op.getBeforeBody();
-    ConditionOp condOp = op.getConditionOp();
-    OperandRange condOpArgs = condOp.getArgs();
-
-    bool canSimplify = false;
-    for (Value condOpArg : condOpArgs) {
-      // Those values not defined within `before` block will be considered as
-      // loop invariant values. We map the corresponding `index` with their
-      // value.
-      if (condOpArg.getParentBlock() != &beforeBlock) {
-        canSimplify = true;
-        break;
-      }
-    }
-
-    if (!canSimplify)
-      return failure();
-
-    Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
-
-    SmallVector<Value> newCondOpArgs;
-    SmallVector<Type> newAfterBlockType;
-    DenseMap<unsigned, Value> condOpInitValMap;
-    SmallVector<Location> newAfterBlockArgLocs;
-    for (const auto &it : llvm::enumerate(condOpArgs)) {
-      auto index = static_cast<unsigned>(it.index());
-      Value condOpArg = it.value();
-      // Those values not defined within `before` block will be considered as
-      // loop invariant values. We map the corresponding `index` with their
-      // value.
-      if (condOpArg.getParentBlock() != &beforeBlock) {
-        condOpInitValMap.insert({index, condOpArg});
-      } else {
-        newCondOpArgs.emplace_back(condOpArg);
-        newAfterBlockType.emplace_back(condOpArg.getType());
-        newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
-      }
-    }
-
-    {
-      OpBuilder::InsertionGuard g(rewriter);
-      rewriter.setInsertionPoint(condOp);
-      rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
-                                               newCondOpArgs);
-    }
-
-    auto newWhile = WhileOp::create(rewriter, op.getLoc(), newAfterBlockType,
-                                    op.getOperands());
-
-    Block &newAfterBlock =
-        *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
-                              newAfterBlockType, newAfterBlockArgLocs);
-
-    Block &afterBlock = *op.getAfterBody();
-    // Since a new scf.condition op was created, we need to fetch the new
-    // `after` block arguments which will be used while replacing operations of
-    // previous scf.while's `after` blocks. We'd also be fetching new result
-    // values too.
-    SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
-    SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
-    for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
-      Value afterBlockArg, result;
-      // If index 'i' argument was loop invariant we fetch it's value from the
-      // `condOpInitMap` map.
-      if (condOpInitValMap.count(i) != 0) {
-        afterBlockArg = condOpInitValMap[i];
-        result = afterBlockArg;
-      } else {
-        afterBlockArg = newAfterBlock.getArgument(j);
-        result = newWhile.getResult(j);
-        j++;
-      }
-      newAfterBlockArgs[i] = afterBlockArg;
-      newWhileResults[i] = result;
-    }
-
-    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
-    rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
-                                newWhile.getBefore().begin());
-
-    rewriter.replaceOp(op, newWhileResults);
-    return success();
-  }
-};
-
-/// Remove WhileOp results that are also unused in 'after' block.
-///
-///  %0:2 = scf.while () : () -> (i32, i64) {
-///    %condition = "test.condition"() : () -> i1
-///    %v1 = "test.get_some_value"() : () -> i32
-///    %v2 = "test.get_some_value"() : () -> i64
-///    scf.condition(%condition) %v1, %v2 : i32, i64
-///  } do {
-///  ^bb0(%arg0: i32, %arg1: i64):
-///    "test.use"(%arg0) : (i32) -> ()
-///    scf.yield
-///  }
-///  return %0#0 : i32
-///
-/// becomes
-///  %0 = scf.while () : () -> (i32) {
-///    %condition = "test.condition"() : () -> i1
-///    %v1 = "test.get_some_value"() : () -> i32
-///    %v2 = "test.get_some_value"() : () -> i64
-///    scf.condition(%condition) %v1 : i32
-///  } do {
-///  ^bb0(%arg0: i32):
-///    "test.use"(%arg0) : (i32) -> ()
-///    scf.yield
-///  }
-///  return %0 : i32
-struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
-  using OpRewritePattern<WhileOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(WhileOp op,
-                                PatternRewriter &rewriter) const override {
-    auto term = op.getConditionOp();
-    auto afterArgs = op.getAfterArguments();
-    auto termArgs = term.getArgs();
-
-    // Collect results mapping, new terminator args and new result types.
-    SmallVector<unsigned> newResultsIndices;
-    SmallVector<Type> newResultTypes;
-    SmallVector<Value> newTermArgs;
-    SmallVector<Location> newArgLocs;
-    bool needUpdate = false;
-    for (const auto &it :
-         llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
-      auto i = static_cast<unsigned>(it.index());
-      Value result = std::get<0>(it.value());
-      Value afterArg = std::get<1>(it.value());
-      Value termArg = std::get<2>(it.value());
-      if (result.use_empty() && afterArg.use_empty()) {
-        needUpdate = true;
-      } else {
-        newResultsIndices.emplace_back(i);
-        newTermArgs.emplace_back(termArg);
-        newResultTypes.emplace_back(result.getType());
-        newArgLocs.emplace_back(result.getLoc());
-      }
-    }
-
-    if (!needUpdate)
-      return failure();
-
-    {
-      OpBuilder::InsertionGuard g(rewriter);
-      rewriter.setInsertionPoint(term);
-      rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
-                                               newTermArgs);
-    }
-
-    auto newWhile =
-        WhileOp::create(rewriter, op.getLoc(), newResultTypes, op.getInits());
-
-    Block &newAfterBlock = *rewriter.createBlock(
-        &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
-
-    // Build new results list and new after block args (unused entries will be
-    // null).
-    SmallVector<Value> newResults(op.getNumResults());
-    SmallVector<Value> newAfterBlockArgs(op.getNumResults());
-    for (const auto &it : llvm::enumerate(newResultsIndices)) {
-      newResults[it.value()] = newWhile.getResult(it.index());
-      newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
-    }
-
-    rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
-                                newWhile.getBefore().begin());
-
-    Block &afterBlock = *op.getAfterBody();
-    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
-
-    rewriter.replaceOp(op, newResults);
-    return success();
-  }
-};
-
 /// Replace operations equivalent to the condition in the do block with true,
 /// since otherwise the block would not be evaluated.
 ///
@@ -4321,127 +3921,6 @@ struct WhileCmpCond : public 
OpRewritePattern<scf::WhileOp> {
   }
 };
 
-/// Remove unused init/yield args.
-struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
-  using OpRewritePattern<WhileOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(WhileOp op,
-                                PatternRewriter &rewriter) const override {
-
-    if (!llvm::any_of(op.getBeforeArguments(),
-                      [](Value arg) { return arg.use_empty(); }))
-      return rewriter.notifyMatchFailure(op, "No args to remove");
-
-    YieldOp yield = op.getYieldOp();
-
-    // Collect results mapping, new terminator args and new result types.
-    SmallVector<Value> newYields;
-    SmallVector<Value> newInits;
-    llvm::BitVector argsToErase;
-
-    size_t argsCount = op.getBeforeArguments().size();
-    newYields.reserve(argsCount);
-    newInits.reserve(argsCount);
-    argsToErase.reserve(argsCount);
-    for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
-             op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
-      if (beforeArg.use_empty()) {
-        argsToErase.push_back(true);
-      } else {
-        argsToErase.push_back(false);
-        newYields.emplace_back(yieldValue);
-        newInits.emplace_back(initValue);
-      }
-    }
-
-    Block &beforeBlock = *op.getBeforeBody();
-    Block &afterBlock = *op.getAfterBody();
-
-    beforeBlock.eraseArguments(argsToErase);
-
-    Location loc = op.getLoc();
-    auto newWhileOp =
-        WhileOp::create(rewriter, loc, op.getResultTypes(), newInits,
-                        /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
-    Block &newBeforeBlock = *newWhileOp.getBeforeBody();
-    Block &newAfterBlock = *newWhileOp.getAfterBody();
-
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(yield);
-    rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
-
-    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
-                         newBeforeBlock.getArguments());
-    rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
-                         newAfterBlock.getArguments());
-
-    rewriter.replaceOp(op, newWhileOp.getResults());
-    return success();
-  }
-};
-
-/// Remove duplicated ConditionOp args.
-struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(WhileOp op,
-                                PatternRewriter &rewriter) const override {
-    ConditionOp condOp = op.getConditionOp();
-    ValueRange condOpArgs = condOp.getArgs();
-
-    llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs);
-
-    if (argsSet.size() == condOpArgs.size())
-      return rewriter.notifyMatchFailure(op, "No results to remove");
-
-    llvm::SmallDenseMap<Value, unsigned> argsMap;
-    SmallVector<Value> newArgs;
-    argsMap.reserve(condOpArgs.size());
-    newArgs.reserve(condOpArgs.size());
-    for (Value arg : condOpArgs) {
-      if (!argsMap.count(arg)) {
-        auto pos = static_cast<unsigned>(argsMap.size());
-        argsMap.insert({arg, pos});
-        newArgs.emplace_back(arg);
-      }
-    }
-
-    ValueRange argsRange(newArgs);
-
-    Location loc = op.getLoc();
-    auto newWhileOp =
-        scf::WhileOp::create(rewriter, loc, argsRange.getTypes(), 
op.getInits(),
-                             /*beforeBody*/ nullptr,
-                             /*afterBody*/ nullptr);
-    Block &newBeforeBlock = *newWhileOp.getBeforeBody();
-    Block &newAfterBlock = *newWhileOp.getAfterBody();
-
-    SmallVector<Value> afterArgsMapping;
-    SmallVector<Value> resultsMapping;
-    for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
-      auto it = argsMap.find(arg);
-      assert(it != argsMap.end());
-      auto pos = it->second;
-      afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
-      resultsMapping.emplace_back(newWhileOp->getResult(pos));
-    }
-
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(condOp);
-    rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
-                                             argsRange);
-
-    Block &beforeBlock = *op.getBeforeBody();
-    Block &afterBlock = *op.getAfterBody();
-
-    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
-                         newBeforeBlock.getArguments());
-    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
-    rewriter.replaceOp(op, resultsMapping);
-    return success();
-  }
-};
-
 /// If both ranges contain same values return mappping indices from args2 to
 /// args1. Otherwise return std::nullopt.
 static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
@@ -4532,11 +4011,8 @@ struct WhileOpAlignBeforeArgs : public 
OpRewritePattern<WhileOp> {
 
 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
-  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
-              RemoveLoopInvariantValueYielded, WhileConditionTruth,
-              WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
-              WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
-      context);
+  results.add<WhileConditionTruth, WhileCmpCond, WhileOpAlignBeforeArgs,
+              WhileMoveIfDown>(context);
 }
 
 
//===----------------------------------------------------------------------===//
@@ -4711,59 +4187,9 @@ struct FoldConstantCase : 
OpRewritePattern<scf::IndexSwitchOp> {
   }
 };
 
-/// Canonicalization patterns that folds away dead results of
-/// "scf.index_switch" ops.
-struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
-  using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(IndexSwitchOp op,
-                                PatternRewriter &rewriter) const override {
-    // Find dead results.
-    BitVector deadResults(op.getNumResults(), false);
-    SmallVector<Type> newResultTypes;
-    for (auto [idx, result] : llvm::enumerate(op.getResults())) {
-      if (!result.use_empty()) {
-        newResultTypes.push_back(result.getType());
-      } else {
-        deadResults[idx] = true;
-      }
-    }
-    if (!deadResults.any())
-      return rewriter.notifyMatchFailure(op, "no dead results to fold");
-
-    // Create new op without dead results and inline case regions.
-    auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes,
-                                       op.getArg(), op.getCases(),
-                                       op.getCaseRegions().size());
-    auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) {
-      rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
-      // Remove respective operands from yield op.
-      Operation *terminator = newRegion.front().getTerminator();
-      assert(isa<YieldOp>(terminator) && "expected yield op");
-      rewriter.modifyOpInPlace(
-          terminator, [&]() { terminator->eraseOperands(deadResults); });
-    };
-    for (auto [oldRegion, newRegion] :
-         llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions()))
-      inlineCaseRegion(oldRegion, newRegion);
-    inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion());
-
-    // Replace op with new op.
-    SmallVector<Value> newResults(op.getNumResults(), Value());
-    unsigned nextNewResult = 0;
-    for (unsigned idx = 0; idx < op.getNumResults(); ++idx) {
-      if (deadResults[idx])
-        continue;
-      newResults[idx] = newOp.getResult(nextNewResult++);
-    }
-    rewriter.replaceOp(op, newResults);
-    return success();
-  }
-};
-
 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context);
+  results.add<FoldConstantCase>(context);
 }
 
 
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp 
b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index d393ddb8d8336..ed94205d32f19 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -521,6 +521,45 @@ void 
RegionBranchOpInterface::getSuccessorOperandInputMapping(
   }
 }
 
+void RegionBranchOpInterface::getSuccessorInputOperandMapping(
+    RegionBranchInverseSuccessorMapping &mapping) {
+  RegionBranchSuccessorMapping operandToInputs;
+  getSuccessorOperandInputMapping(operandToInputs); // operand -> inputs
+  for (const auto &[operand, inputs] : operandToInputs) {
+    for (Value input : inputs)
+      mapping[input].push_back(operand); // Inverted: input -> operands.
+  }
+}
+
+DenseSet<Value>
+RegionBranchOpInterface::computePossibleValuesOfSuccessorInput(Value value) {
+  RegionBranchInverseSuccessorMapping inputToOperands;
+  getSuccessorInputOperandMapping(inputToOperands);
+
+  DenseSet<Value> possibleValues;
+  DenseSet<Value> visited;
+  SmallVector<Value> worklist;
+
+  // Starting with the given value, trace back all predecessor values (i.e.,
+  // preceding successor operands) and add them to the set of possible values.
+  // If a successor operand is itself a successor input, do not add it to the
+  // result set, but instead continue the traversal through its operands.
+  worklist.push_back(value);
+  while (!worklist.empty()) {
+    Value next = worklist.pop_back_val();
+    auto it = inputToOperands.find(next);
+    if (it == inputToOperands.end()) { // `next` is not a successor input.
+      possibleValues.insert(next);
+      continue;
+    }
+    for (OpOperand *operand : it->second)
+      if (visited.insert(operand->get()).second) // Enqueue each value once.
+        worklist.push_back(operand->get());
+  }
+
+  return possibleValues;
+}
+
 SmallVector<RegionBranchPoint>
 RegionBranchOpInterface::getAllRegionBranchPoints() {
   SmallVector<RegionBranchPoint> branchPoints;
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir 
b/mlir/test/Dialect/SCF/canonicalize.mlir
index d5d0aee3bbe25..11dc4f04af32e 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1071,17 +1071,17 @@ func.func @invariant_loop_args_in_same_order(%f_arg0: 
tensor<i32>) -> (tensor<i3
 // CHECK:    %[[ZERO:.*]] = arith.constant dense<0>
 // CHECK:    %[[ONE:.*]] = arith.constant dense<1>
 // CHECK:    %[[CST42:.*]] = arith.constant dense<42>
-// CHECK:    %[[WHILE:.*]]:3 = scf.while (%[[ARG0:.*]] = %[[ZERO]], 
%[[ARG2:.*]] = %[[ONE]], %[[ARG3:.*]] = %[[ONE]])
+// CHECK:    %[[WHILE:.*]]:2 = scf.while (%[[ARG0:.*]] = %[[ZERO]], 
%[[ARG2:.*]] = %[[ONE]]) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, 
tensor<i32>)
 // CHECK:       arith.cmpi slt, %[[ARG0]], %{{.*}}
 // CHECK:       tensor.extract %{{.*}}[]
-// CHECK:       scf.condition(%{{.*}}) %[[ARG0]], %[[ARG2]], %[[ARG3]]
+// CHECK:       scf.condition(%{{.*}}) %[[ARG0]], %[[ARG2]]
 // CHECK:    } do {
-// CHECK:     ^{{.*}}(%[[ARG0:.*]]: tensor<i32>, %[[ARG2:.*]]: tensor<i32>, 
%[[ARG3:.*]]: tensor<i32>):
+// CHECK:     ^{{.*}}(%[[ARG0:.*]]: tensor<i32>, %[[ARG2:.*]]: tensor<i32>):
 // CHECK:       %[[VAL0:.*]] = arith.addi %[[ARG0]], %[[FUNC_ARG0]]
-// CHECK:       %[[VAL1:.*]] = arith.addi %[[ARG2]], %[[ARG3]]
-// CHECK:       scf.yield %[[VAL0]], %[[VAL1]], %[[VAL1]]
+// CHECK:       %[[VAL1:.*]] = arith.addi %[[ARG2]], %[[ARG2]]
+// CHECK:       scf.yield %[[VAL0]], %[[VAL1]]
 // CHECK:    }
-// CHECK:    return %[[WHILE]]#0, %[[FUNC_ARG0]], %[[WHILE]]#1, %[[WHILE]]#2, 
%[[ZERO]]
+// CHECK:    return %[[WHILE]]#0, %[[FUNC_ARG0]], %[[WHILE]]#1, %[[WHILE]]#1, 
%[[ZERO]]
 
 // CHECK-LABEL: @while_loop_invariant_argument_different_order
 func.func @while_loop_invariant_argument_different_order(%arg : tensor<i32>) 
-> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, 
tensor<i32>) {
@@ -1736,11 +1736,11 @@ module {
 
 // Test case with multiple scf.yield ops with at least one different operand, 
then no change.
 
-// CHECK:           %[[VAL_3:.*]]:2 = scf.execute_region -> (memref<1x60xui8>, 
memref<1x120xui8>) no_inline {
+// CHECK:           %[[VAL_3:.*]] = scf.execute_region -> memref<1x120xui8> 
no_inline {
 // CHECK:           ^bb1:
-// CHECK:             scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, 
memref<1x120xui8>
+// CHECK:             scf.yield %{{.*}} : memref<1x120xui8>
 // CHECK:           ^bb2:
-// CHECK:             scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, 
memref<1x120xui8>
+// CHECK:             scf.yield %{{.*}} : memref<1x120xui8>
 // CHECK:           }
 
 module {
@@ -2178,16 +2178,14 @@ func.func @scf_for_all_step_size_0()  {
 //  CHECK-SAME:     %[[arg0:.*]]: index
 //   CHECK-DAG:   %[[c10:.*]] = arith.constant 10
 //   CHECK-DAG:   %[[c11:.*]] = arith.constant 11
-//       CHECK:   %[[switch:.*]] = scf.index_switch %[[arg0]] -> index
+//       CHECK:   scf.index_switch %[[arg0]]
 //       CHECK:   case 1 {
 //       CHECK:     memref.store %[[c10]]
-//       CHECK:     scf.yield %[[arg0]] : index
 //       CHECK:   } 
 //       CHECK:   default {
 //       CHECK:     memref.store %[[c11]]
-//       CHECK:     scf.yield %[[arg0]] : index
 //       CHECK:   }
-//       CHECK:   return %[[switch]]
+//       CHECK:   return %[[arg0]]
 func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> 
index {
   %non_live, %live = scf.index_switch %arg0 -> i32, index
   case 1 {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir 
b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
index 5f2aa5e3a2736..b0a1af31d8806 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
@@ -129,13 +129,13 @@ func.func @matmul_sparse_rhs(%a: tensor<10x20xf32>,
 // CHECK:             %[[VAL_29:.*]] = memref.load 
%[[VAL_9]]{{\[}}%[[VAL_28]]] : memref<?xindex>
 // CHECK:             %[[VAL_30:.*]] = memref.load 
%[[VAL_12]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK:             %[[VAL_31:.*]] = memref.load 
%[[VAL_12]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK:             %[[VAL_32:.*]]:4 = scf.while (%[[VAL_33:.*]] = 
%[[VAL_27]], %[[VAL_34:.*]] = %[[VAL_30]], %[[VAL_35:.*]] = %[[VAL_26]], 
%[[VAL_36:.*]] = %[[VAL_21]]) : (index, index, index, tensor<4x4xf64, 
#sparse{{[0-9]*}}>) -> (index, index, index, tensor<4x4xf64, 
#sparse{{[0-9]*}}>) {
+// CHECK:             %[[VAL_32:.*]]:3 = scf.while (%[[VAL_33:.*]] = 
%[[VAL_27]], %[[VAL_34:.*]] = %[[VAL_30]], %[[VAL_35:.*]] = %[[VAL_26]]) : 
(index, index, index) -> (index, index, index) {
 // CHECK:               %[[VAL_37:.*]] = arith.cmpi ult, %[[VAL_33]], 
%[[VAL_29]] : index
 // CHECK:               %[[VAL_38:.*]] = arith.cmpi ult, %[[VAL_34]], 
%[[VAL_31]] : index
 // CHECK:               %[[VAL_39:.*]] = arith.andi %[[VAL_37]], %[[VAL_38]] : 
i1
-// CHECK:               scf.condition(%[[VAL_39]]) %[[VAL_33]], %[[VAL_34]], 
%[[VAL_35]], %[[VAL_36]] : index, index, index, tensor<4x4xf64, 
#sparse{{[0-9]*}}>
+// CHECK:               scf.condition(%[[VAL_39]]) %[[VAL_33]], %[[VAL_34]], 
%[[VAL_35]] : index, index, index
 // CHECK:             } do {
-// CHECK:             ^bb0(%[[VAL_40:.*]]: index, %[[VAL_41:.*]]: index, 
%[[VAL_42:.*]]: index, %[[VAL_43:.*]]: tensor<4x4xf64, #sparse{{[0-9]*}}>):
+// CHECK:             ^bb0(%[[VAL_40:.*]]: index, %[[VAL_41:.*]]: index, 
%[[VAL_42:.*]]: index):
 // CHECK:               %[[VAL_44:.*]] = memref.load 
%[[VAL_10]]{{\[}}%[[VAL_40]]] : memref<?xindex>
 // CHECK:               %[[VAL_45:.*]] = memref.load 
%[[VAL_13]]{{\[}}%[[VAL_41]]] : memref<?xindex>
 // CHECK:               %[[VAL_46:.*]] = arith.cmpi ult, %[[VAL_45]], 
%[[VAL_44]] : index
@@ -143,7 +143,7 @@ func.func @matmul_sparse_rhs(%a: tensor<10x20xf32>,
 // CHECK:               %[[VAL_48:.*]] = arith.cmpi eq, %[[VAL_44]], 
%[[VAL_47]] : index
 // CHECK:               %[[VAL_49:.*]] = arith.cmpi eq, %[[VAL_45]], 
%[[VAL_47]] : index
 // CHECK:               %[[VAL_50:.*]] = arith.andi %[[VAL_48]], %[[VAL_49]] : 
i1
-// CHECK:               %[[VAL_51:.*]]:2 = scf.if %[[VAL_50]] -> (index, 
tensor<4x4xf64, #sparse{{[0-9]*}}>) {
+// CHECK:               %[[VAL_51:.*]] = scf.if %[[VAL_50]] -> (index) {
 // CHECK:                 %[[VAL_52:.*]] = memref.load 
%[[VAL_11]]{{\[}}%[[VAL_40]]] : memref<?xf64>
 // CHECK:                 %[[VAL_53:.*]] = memref.load 
%[[VAL_14]]{{\[}}%[[VAL_41]]] : memref<?xindex>
 // CHECK:                 %[[VAL_54:.*]] = arith.addi %[[VAL_41]], %[[VAL_3]] 
: index
@@ -167,9 +167,9 @@ func.func @matmul_sparse_rhs(%a: tensor<10x20xf32>,
 // CHECK:                   memref.store %[[VAL_63]], 
%[[VAL_23]]{{\[}}%[[VAL_59]]] : memref<?xf64>
 // CHECK:                   scf.yield %[[VAL_68:.*]] : index
 // CHECK:                 }
-// CHECK:                 scf.yield %[[VAL_69:.*]], %[[VAL_43]] : index, 
tensor<4x4xf64, #sparse{{[0-9]*}}>
+// CHECK:                 scf.yield %[[VAL_69:.*]] : index
 // CHECK:               } else {
-// CHECK:                 scf.yield %[[VAL_42]], %[[VAL_43]] : index, 
tensor<4x4xf64, #sparse{{[0-9]*}}>
+// CHECK:                 scf.yield %[[VAL_42]] : index
 // CHECK:               }
 // CHECK:               %[[VAL_70:.*]] = arith.cmpi eq, %[[VAL_44]], 
%[[VAL_47]] : index
 // CHECK:               %[[VAL_71:.*]] = arith.addi %[[VAL_40]], %[[VAL_3]] : 
index
@@ -177,9 +177,9 @@ func.func @matmul_sparse_rhs(%a: tensor<10x20xf32>,
 // CHECK:               %[[VAL_73:.*]] = arith.cmpi eq, %[[VAL_45]], 
%[[VAL_47]] : index
 // CHECK:               %[[VAL_74:.*]] = arith.addi %[[VAL_41]], %[[VAL_3]] : 
index
 // CHECK:               %[[VAL_75:.*]] = arith.select %[[VAL_73]], 
%[[VAL_74]], %[[VAL_41]] : index
-// CHECK:               scf.yield %[[VAL_72]], %[[VAL_75]], %[[VAL_76:.*]]#0, 
%[[VAL_76]]#1 : index, index, index, tensor<4x4xf64, #sparse{{[0-9]*}}>
+// CHECK:               scf.yield %[[VAL_72]], %[[VAL_75]], %[[VAL_76:.*]] : 
index, index, index
 // CHECK:             }
-// CHECK:             %[[VAL_77:.*]] = sparse_tensor.compress %[[VAL_23]], 
%[[VAL_24]], %[[VAL_25]], %[[VAL_78:.*]]#2 into %[[VAL_78]]#3{{\[}}%[[VAL_22]]] 
: memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<4x4xf64, 
#sparse{{[0-9]*}}>
+// CHECK:             %[[VAL_77:.*]] = sparse_tensor.compress %[[VAL_23]], 
%[[VAL_24]], %[[VAL_25]], %[[VAL_78:.*]]#2 into %[[VAL_21]]{{\[}}%[[VAL_22]]] : 
memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<4x4xf64, #sparse{{[0-9]*}}>
 // CHECK:             scf.yield %[[VAL_77]] : tensor<4x4xf64, 
#sparse{{[0-9]*}}>
 // CHECK:           }
 // CHECK:           %[[VAL_79:.*]] = sparse_tensor.load %[[VAL_80:.*]] 
hasInserts : tensor<4x4xf64, #sparse{{[0-9]*}}>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir 
b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
index 4dff06b8155dd..67d1573058460 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
@@ -216,13 +216,13 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, 
#CSR>) -> tensor<10x20x
 // CHECK:                   %[[VAL_71:.*]] = memref.load 
%[[VAL_19]]{{\[}}%[[VAL_58]]] : memref<?xindex>
 // CHECK:                   %[[VAL_72:.*]] = arith.addi %[[VAL_58]], 
%[[VAL_3]] : index
 // CHECK:                   %[[VAL_73:.*]] = memref.load 
%[[VAL_19]]{{\[}}%[[VAL_72]]] : memref<?xindex>
-// CHECK:                   %[[VAL_74:.*]]:5 = scf.while (%[[VAL_75:.*]] = 
%[[VAL_68]], %[[VAL_76:.*]] = %[[VAL_71]], %[[VAL_77:.*]] = %[[VAL_4]], 
%[[VAL_200:.*]] = %[[VAL_FALSE]], %[[VAL_78:.*]] = %[[VAL_59]]) : (index, 
index, i32, i1, tensor<?x?xi32, #sparse{{[0-9]*}}>) -> (index, index, i32, i1, 
tensor<?x?xi32, #sparse{{[0-9]*}}>) {
+// CHECK:                   %[[VAL_74:.*]]:4 = scf.while (%[[VAL_75:.*]] = 
%[[VAL_68]], %[[VAL_76:.*]] = %[[VAL_71]], %[[VAL_77:.*]] = %[[VAL_4]], 
%[[VAL_200:.*]] = %[[VAL_FALSE]]) : (index, index, i32, i1) -> (index, index, 
i32, i1) {
 // CHECK:                     %[[VAL_79:.*]] = arith.cmpi ult, %[[VAL_75]], 
%[[VAL_70]] : index
 // CHECK:                     %[[VAL_80:.*]] = arith.cmpi ult, %[[VAL_76]], 
%[[VAL_73]] : index
 // CHECK:                     %[[VAL_81:.*]] = arith.andi %[[VAL_79]], 
%[[VAL_80]] : i1
-// CHECK:                     scf.condition(%[[VAL_81]]) %[[VAL_75]], 
%[[VAL_76]], %[[VAL_77]], %[[VAL_200]], %[[VAL_78]] : index, index, i32, i1, 
tensor<?x?xi32, #sparse{{[0-9]*}}>
+// CHECK:                     scf.condition(%[[VAL_81]]) %[[VAL_75]], 
%[[VAL_76]], %[[VAL_77]], %[[VAL_200]] : index, index, i32, i1
 // CHECK:                   } do {
-// CHECK:                   ^bb0(%[[VAL_82:.*]]: index, %[[VAL_83:.*]]: index, 
%[[VAL_84:.*]]: i32, %[[VAL_201:.*]]: i1, %[[VAL_85:.*]]: tensor<?x?xi32, 
#sparse{{[0-9]*}}>):
+// CHECK:                   ^bb0(%[[VAL_82:.*]]: index, %[[VAL_83:.*]]: index, 
%[[VAL_84:.*]]: i32, %[[VAL_201:.*]]: i1):
 // CHECK:                     %[[VAL_86:.*]] = memref.load 
%[[VAL_13]]{{\[}}%[[VAL_82]]] : memref<?xindex>
 // CHECK:                     %[[VAL_87:.*]] = memref.load 
%[[VAL_20]]{{\[}}%[[VAL_83]]] : memref<?xindex>
 // CHECK:                     %[[VAL_88:.*]] = arith.cmpi ult, %[[VAL_87]], 
%[[VAL_86]] : index
@@ -230,14 +230,14 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, 
#CSR>) -> tensor<10x20x
 // CHECK:                     %[[VAL_90:.*]] = arith.cmpi eq, %[[VAL_86]], 
%[[VAL_89]] : index
 // CHECK:                     %[[VAL_91:.*]] = arith.cmpi eq, %[[VAL_87]], 
%[[VAL_89]] : index
 // CHECK:                     %[[VAL_92:.*]] = arith.andi %[[VAL_90]], 
%[[VAL_91]] : i1
-// CHECK:                     %[[VAL_93:.*]]:3 = scf.if %[[VAL_92]] -> (i32, 
i1, tensor<?x?xi32, #sparse{{[0-9]*}}>) {
+// CHECK:                     %[[VAL_93:.*]]:2 = scf.if %[[VAL_92]] -> (i32, 
i1) {
 // CHECK:                       %[[VAL_94:.*]] = memref.load 
%[[VAL_14]]{{\[}}%[[VAL_82]]] : memref<?xi32>
 // CHECK:                       %[[VAL_95:.*]] = memref.load 
%[[VAL_21]]{{\[}}%[[VAL_83]]] : memref<?xi32>
 // CHECK:                       %[[VAL_96:.*]] = arith.muli %[[VAL_94]], 
%[[VAL_95]] : i32
 // CHECK:                       %[[VAL_97:.*]] = arith.addi %[[VAL_84]], 
%[[VAL_96]] : i32
-// CHECK:                       scf.yield %[[VAL_97]], %[[VAL_TRUE]], 
%[[VAL_85]] : i32, i1, tensor<?x?xi32, #sparse{{[0-9]*}}>
+// CHECK:                       scf.yield %[[VAL_97]], %[[VAL_TRUE]] : i32, i1
 // CHECK:                     } else {
-// CHECK:                       scf.yield %[[VAL_84]], %[[VAL_201]], 
%[[VAL_85]] : i32, i1, tensor<?x?xi32, #sparse{{[0-9]*}}>
+// CHECK:                       scf.yield %[[VAL_84]], %[[VAL_201]] : i32, i1
 // CHECK:                     }
 // CHECK:                     %[[VAL_98:.*]] = arith.cmpi eq, %[[VAL_86]], 
%[[VAL_89]] : index
 // CHECK:                     %[[VAL_99:.*]] = arith.addi %[[VAL_82]], 
%[[VAL_3]] : index
@@ -245,13 +245,13 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, 
#CSR>) -> tensor<10x20x
 // CHECK:                     %[[VAL_101:.*]] = arith.cmpi eq, %[[VAL_87]], 
%[[VAL_89]] : index
 // CHECK:                     %[[VAL_102:.*]] = arith.addi %[[VAL_83]], 
%[[VAL_3]] : index
 // CHECK:                     %[[VAL_103:.*]] = arith.select %[[VAL_101]], 
%[[VAL_102]], %[[VAL_83]] : index
-// CHECK:                     scf.yield %[[VAL_100]], %[[VAL_103]], 
%[[VAL_104:.*]]#0, %[[VAL_104]]#1, %[[VAL_104]]#2 : index, index, i32, i1, 
tensor<?x?xi32, #sparse{{[0-9]*}}>
+// CHECK:                     scf.yield %[[VAL_100]], %[[VAL_103]], 
%[[VAL_104:.*]]#0, %[[VAL_104]]#1 : index, index, i32, i1
 // CHECK:                   }
 // CHECK:                   %[[VAL_202:.*]] = scf.if %[[VAL_74]]#3 -> 
(tensor<?x?xi32, #sparse{{[0-9]*}}>) {
-// CHECK:                     %[[VAL_105:.*]] = tensor.insert %[[VAL_74]]#2 
into %[[VAL_74]]#4{{\[}}%[[VAL_39]], %[[VAL_63]]] : tensor<?x?xi32, 
#sparse{{[0-9]*}}>
+// CHECK:                     %[[VAL_105:.*]] = tensor.insert %[[VAL_74]]#2 
into %[[VAL_59]]{{\[}}%[[VAL_39]], %[[VAL_63]]] : tensor<?x?xi32, 
#sparse{{[0-9]*}}>
 // CHECK:                     scf.yield %[[VAL_105]] : tensor<?x?xi32, 
#sparse{{[0-9]*}}>
 // CHECK:                   } else {
-// CHECK:                     scf.yield %[[VAL_74]]#4 : tensor<?x?xi32, 
#sparse{{[0-9]*}}>
+// CHECK:                     scf.yield %[[VAL_59]] : tensor<?x?xi32, 
#sparse{{[0-9]*}}>
 // CHECK:                   }
 // CHECK:                   scf.yield %[[VAL_202]] : tensor<?x?xi32, 
#sparse{{[0-9]*}}>
 // CHECK:                 } else {
@@ -339,13 +339,13 @@ func.func @sumred(%arga: tensor<?x?x?xi32, #SparseTensor>,
 // CHECK:             %[[VAL_31:.*]] = memref.load 
%[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<?xindex>
 // CHECK:             %[[VAL_32:.*]] = memref.load 
%[[VAL_14]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK:             %[[VAL_33:.*]] = memref.load 
%[[VAL_14]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK:             %[[VAL_34:.*]]:4 = scf.while (%[[VAL_35:.*]] = 
%[[VAL_29]], %[[VAL_36:.*]] = %[[VAL_32]], %[[VAL_37:.*]] = %[[VAL_28]], 
%[[VAL_38:.*]] = %[[VAL_23]]) : (index, index, index, tensor<?x?xf32, 
#sparse{{[0-9]*}}>) -> (index, index, index, tensor<?x?xf32, 
#sparse{{[0-9]*}}>) {
+// CHECK:             %[[VAL_34:.*]]:3 = scf.while (%[[VAL_35:.*]] = 
%[[VAL_29]], %[[VAL_36:.*]] = %[[VAL_32]], %[[VAL_37:.*]] = %[[VAL_28]]) : 
(index, index, index) -> (index, index, index) {
 // CHECK:               %[[VAL_39:.*]] = arith.cmpi ult, %[[VAL_35]], 
%[[VAL_31]] : index
 // CHECK:               %[[VAL_40:.*]] = arith.cmpi ult, %[[VAL_36]], 
%[[VAL_33]] : index
 // CHECK:               %[[VAL_41:.*]] = arith.andi %[[VAL_39]], %[[VAL_40]] : 
i1
-// CHECK:               scf.condition(%[[VAL_41]]) %[[VAL_35]], %[[VAL_36]], 
%[[VAL_37]], %[[VAL_38]] : index, index, index, tensor<?x?xf32, 
#sparse{{[0-9]*}}>
+// CHECK:               scf.condition(%[[VAL_41]]) %[[VAL_35]], %[[VAL_36]], 
%[[VAL_37]] : index, index, index
 // CHECK:             } do {
-// CHECK:             ^bb0(%[[VAL_42:.*]]: index, %[[VAL_43:.*]]: index, 
%[[VAL_44:.*]]: index, %[[VAL_45:.*]]: tensor<?x?xf32, #sparse{{[0-9]*}}>):
+// CHECK:             ^bb0(%[[VAL_42:.*]]: index, %[[VAL_43:.*]]: index, 
%[[VAL_44:.*]]: index):
 // CHECK:               %[[VAL_46:.*]] = memref.load 
%[[VAL_12]]{{\[}}%[[VAL_42]]] : memref<?xindex>
 // CHECK:               %[[VAL_47:.*]] = memref.load 
%[[VAL_15]]{{\[}}%[[VAL_43]]] : memref<?xindex>
 // CHECK:               %[[VAL_48:.*]] = arith.cmpi ult, %[[VAL_47]], 
%[[VAL_46]] : index
@@ -353,7 +353,7 @@ func.func @sumred(%arga: tensor<?x?x?xi32, #SparseTensor>,
 // CHECK:               %[[VAL_50:.*]] = arith.cmpi eq, %[[VAL_46]], 
%[[VAL_49]] : index
 // CHECK:               %[[VAL_51:.*]] = arith.cmpi eq, %[[VAL_47]], 
%[[VAL_49]] : index
 // CHECK:               %[[VAL_52:.*]] = arith.andi %[[VAL_50]], %[[VAL_51]] : 
i1
-// CHECK:               %[[VAL_53:.*]]:2 = scf.if %[[VAL_52]] -> (index, 
tensor<?x?xf32, #sparse{{[0-9]*}}>) {
+// CHECK:               %[[VAL_53:.*]] = scf.if %[[VAL_52]] -> (index) {
 // CHECK:                 %[[VAL_54:.*]] = memref.load 
%[[VAL_13]]{{\[}}%[[VAL_42]]] : memref<?xf32>
 // CHECK:                 %[[VAL_55:.*]] = memref.load 
%[[VAL_16]]{{\[}}%[[VAL_43]]] : memref<?xindex>
 // CHECK:                 %[[VAL_56:.*]] = arith.addi %[[VAL_43]], %[[VAL_3]] 
: index
@@ -377,9 +377,9 @@ func.func @sumred(%arga: tensor<?x?x?xi32, #SparseTensor>,
 // CHECK:                   memref.store %[[VAL_65]], 
%[[VAL_25]]{{\[}}%[[VAL_61]]] : memref<?xf32>
 // CHECK:                   scf.yield %[[VAL_70:.*]] : index
 // CHECK:                 }
-// CHECK:                 scf.yield %[[VAL_71:.*]], %[[VAL_45]] : index, 
tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK:                 scf.yield %[[VAL_71:.*]] : index
 // CHECK:               } else {
-// CHECK:                 scf.yield %[[VAL_44]], %[[VAL_45]] : index, 
tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK:                 scf.yield %[[VAL_44]] : index
 // CHECK:               }
 // CHECK:               %[[VAL_72:.*]] = arith.cmpi eq, %[[VAL_46]], 
%[[VAL_49]] : index
 // CHECK:               %[[VAL_73:.*]] = arith.addi %[[VAL_42]], %[[VAL_3]] : 
index
@@ -387,9 +387,9 @@ func.func @sumred(%arga: tensor<?x?x?xi32, #SparseTensor>,
 // CHECK:               %[[VAL_75:.*]] = arith.cmpi eq, %[[VAL_47]], 
%[[VAL_49]] : index
 // CHECK:               %[[VAL_76:.*]] = arith.addi %[[VAL_43]], %[[VAL_3]] : 
index
 // CHECK:               %[[VAL_77:.*]] = arith.select %[[VAL_75]], 
%[[VAL_76]], %[[VAL_43]] : index
-// CHECK:               scf.yield %[[VAL_74]], %[[VAL_77]], %[[VAL_78:.*]]#0, 
%[[VAL_78]]#1 : index, index, index, tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK:               scf.yield %[[VAL_74]], %[[VAL_77]], %[[VAL_78:.*]] : 
index, index, index
 // CHECK:             }
-// CHECK:             %[[VAL_79:.*]] = sparse_tensor.compress %[[VAL_25]], 
%[[VAL_26]], %[[VAL_27]], %[[VAL_80:.*]]#2 into %[[VAL_80]]#3{{\[}}%[[VAL_24]]] 
: memref<?xf32>, memref<?xi1>, memref<?xindex>, tensor<?x?xf32, 
#sparse{{[0-9]*}}>
+// CHECK:             %[[VAL_79:.*]] = sparse_tensor.compress %[[VAL_25]], 
%[[VAL_26]], %[[VAL_27]], %[[VAL_80:.*]]#2 into %[[VAL_23]]{{\[}}%[[VAL_24]]] : 
memref<?xf32>, memref<?xi1>, memref<?xindex>, tensor<?x?xf32, #sparse{{[0-9]*}}>
 // CHECK:             scf.yield %[[VAL_79]] : tensor<?x?xf32, 
#sparse{{[0-9]*}}>
 // CHECK:           }
 // CHECK:           %[[VAL_81:.*]] = sparse_tensor.load %[[VAL_82:.*]] 
hasInserts : tensor<?x?xf32, #sparse{{[0-9]*}}>
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir 
b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 135db02d543ef..18fb6852f6875 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1330,11 +1330,11 @@ func.func @vector_insert_1d_broadcast(%laneid: index, 
%pos: index) -> (vector<96
 // -----
 
 // CHECK-PROP-LABEL: func @vector_insert_0d(
-//       CHECK-PROP:   %[[W:.*]]:2 = gpu.warp_execute_on_lane_0{{.*}} -> 
(vector<f32>, f32)
+//       CHECK-PROP:   %[[W:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (f32)
 //       CHECK-PROP:     %[[VEC:.*]] = "some_def"
 //       CHECK-PROP:     %[[VAL:.*]] = "another_def"
-//       CHECK-PROP:     gpu.yield %[[VEC]], %[[VAL]]
-//       CHECK-PROP:   vector.broadcast %[[W]]#1 : f32 to vector<f32>
+//       CHECK-PROP:     gpu.yield %[[VAL]]
+//       CHECK-PROP:   vector.broadcast %[[W]] : f32 to vector<f32>
 func.func @vector_insert_0d(%laneid: index) -> (vector<f32>) {
   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<f32>) {
     %0 = "some_def"() : () -> (vector<f32>)
diff --git a/mlir/test/Transforms/remove-dead-values.mlir 
b/mlir/test/Transforms/remove-dead-values.mlir
index b9a883dbd524e..5bf5487974d35 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -250,13 +250,13 @@ func.func @main() -> (i32, i32) {
 // CHECK-NEXT:  }
 
 // CHECK-CANONICALIZE:       func.func 
@clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]:
 i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 {
-// CHECK-CANONICALIZE-NEXT:    %[[live_and_non_live:.*]]:2 = scf.while 
(%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) {
+// CHECK-CANONICALIZE:         %[[live_and_non_live:.*]]:2 = scf.while 
(%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) {
 // CHECK-CANONICALIZE-NEXT:      %[[live_0:.*]] = arith.addi %[[arg4]], 
%[[arg4]]
-// CHECK-CANONICALIZE-NEXT:      scf.condition(%arg0) %[[live_0]], %[[arg4]] : 
i32, i32
+// CHECK-CANONICALIZE:           scf.condition(%arg0) %[[live_0]], %[[arg4]] : 
i32, i32
 // CHECK-CANONICALIZE-NEXT:    } do {
 // CHECK-CANONICALIZE-NEXT:    ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
 // CHECK-CANONICALIZE-NEXT:      %[[live_1:.*]] = arith.addi %[[arg6]], 
%[[arg6]]
-// CHECK-CANONICALIZE-NEXT:      scf.yield %[[live_1]] : i32
+// CHECK-CANONICALIZE:           scf.yield %[[live_1]] : i32
 // CHECK-CANONICALIZE-NEXT:    }
 // CHECK-CANONICALIZE-NEXT:    return %[[live_and_non_live]]#0
 // CHECK-CANONICALIZE-NEXT:  }
@@ -306,7 +306,7 @@ func.func 
@clean_region_branch_op_dont_remove_first_2_results_but_remove_first_o
 // CHECK-CANONICALIZE:       func.func 
@clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]:
 i1) -> i32 {
 // CHECK-CANONICALIZE-NEXT:    %[[c0:.*]] = arith.constant 0
 // CHECK-CANONICALIZE-NEXT:    %[[c1:.*]] = arith.constant 1
-// CHECK-CANONICALIZE-NEXT:    %[[live_and_non_live:.*]]:2 = scf.while 
(%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) {
+// CHECK-CANONICALIZE:         %[[live_and_non_live:.*]]:2 = scf.while 
(%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) {
 // CHECK-CANONICALIZE-NEXT:      func.call @identity() : () -> ()
 // CHECK-CANONICALIZE-NEXT:      scf.condition(%[[arg2]]) %[[arg3]], %[[arg4]] 
: i32, i32
 // CHECK-CANONICALIZE-NEXT:    } do {

_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to