https://github.com/matthias-springer updated 
https://github.com/llvm/llvm-project/pull/174094

>From da0a853c603fd3acaddf502d6520c70376c77481 Mon Sep 17 00:00:00 2001
From: Matthias Springer <[email protected]>
Date: Wed, 31 Dec 2025 14:07:51 +0000
Subject: [PATCH 1/2] [mlir][draft] Consolidate patterns into
 RegionBranchOpInterface patterns

---
 .../mlir/Interfaces/ControlFlowInterfaces.h   |   2 +
 .../mlir/Interfaces/ControlFlowInterfaces.td  |   9 +
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 908 ++++--------------
 mlir/lib/Interfaces/ControlFlowInterfaces.cpp |  39 +
 mlir/test/Dialect/SCF/canonicalize.mlir       |  12 +-
 5 files changed, 243 insertions(+), 727 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 566f4b8fadb5d..a7565f9f7bb78 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -188,6 +188,11 @@ LogicalResult verifyTypesAlongControlFlowEdges(Operation *op);
 /// possible successors.) Operands that not forwarded at all are not present in
 /// the mapping.
 using RegionBranchSuccessorMapping = DenseMap<OpOperand *, SmallVector<Value>>;
+
+/// Inverse of RegionBranchSuccessorMapping: maps each successor input to the
+/// successor operands that may be forwarded to it.
+using RegionBranchInverseSuccessorMapping =
+    DenseMap<Value, SmallVector<OpOperand *>>;
 
 /// This class represents a successor of a region. A region successor can either
 /// be another region, or the parent operation. If the successor is a region,
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 2e654ba04ffe5..9366e5562b774 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -355,6 +355,16 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
         ::mlir::RegionBranchSuccessorMapping &mapping,
         std::optional<::mlir::RegionBranchPoint> src = std::nullopt);
 
+    /// Build a mapping from successor inputs to successor operands. This is
+    /// the same as "getSuccessorOperandInputMapping", but inverted.
+    void getSuccessorInputOperandMapping(
+        ::mlir::RegionBranchInverseSuccessorMapping &mapping);
+
+    /// Compute all values that a successor input could possibly have. If the
+    /// given value is not a successor input, return an empty set.
+    ::llvm::DenseSet<::mlir::Value> computePossibleValuesOfSuccessorInput(
+        ::mlir::Value value);
+
     /// Return all possible region branch points: the region branch op itself
     /// and all region branch terminators.
     ::llvm::SmallVector<::mlir::RegionBranchPoint> getAllRegionBranchPoints();
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 46d09abd89d69..6e1538676b1e5 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/EquivalenceClasses.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallPtrSet.h"
@@ -291,102 +292,12 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
   }
 };
 
-// Pattern to eliminate ExecuteRegionOp results which forward external
-// values from the region. In case there are multiple yield operations,
-// all of them must have the same operands in order for the pattern to be
-// applicable.
-struct ExecuteRegionForwardingEliminator
-    : public OpRewritePattern<ExecuteRegionOp> {
-  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExecuteRegionOp op,
-                                PatternRewriter &rewriter) const override {
-    if (op.getNumResults() == 0)
-      return failure();
-
-    SmallVector<Operation *> yieldOps;
-    for (Block &block : op.getRegion()) {
-      if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator()))
-        yieldOps.push_back(yield.getOperation());
-    }
-
-    if (yieldOps.empty())
-      return failure();
-
-    // Check if all yield operations have the same operands.
-    auto yieldOpsOperands = yieldOps[0]->getOperands();
-    for (auto *yieldOp : yieldOps) {
-      if (yieldOp->getOperands() != yieldOpsOperands)
-        return failure();
-    }
-
-    SmallVector<Value> externalValues;
-    SmallVector<Value> internalValues;
-    SmallVector<Value> opResultsToReplaceWithExternalValues;
-    SmallVector<Value> opResultsToKeep;
-    for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) {
-      if (isValueFromInsideRegion(yieldedValue, op)) {
-        internalValues.push_back(yieldedValue);
-        opResultsToKeep.push_back(op.getResult(index));
-      } else {
-        externalValues.push_back(yieldedValue);
-        opResultsToReplaceWithExternalValues.push_back(op.getResult(index));
-      }
-    }
-    // No yielded external values - nothing to do.
-    if (externalValues.empty())
-      return failure();
-
-    // There are yielded external values - create a new execute_region returning
-    // just the internal values.
-    SmallVector<Type> resultTypes;
-    for (Value value : internalValues)
-      resultTypes.push_back(value.getType());
-    auto newOp =
-        ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes));
-    newOp->setAttrs(op->getAttrs());
-
-    // Move old op's region to the new operation.
-    rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
-                                newOp.getRegion().end());
-
-    // Replace all yield operations with a new yield operation with updated
-    // results. scf.execute_region must have at least one yield operation.
-    for (auto *yieldOp : yieldOps) {
-      rewriter.setInsertionPoint(yieldOp);
-      rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp,
-                                                ValueRange(internalValues));
-    }
-
-    // Replace the old operation with the external values directly.
-    rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues,
-                                externalValues);
-    // Replace the old operation's remaining results with the new operation's
-    // results.
-    rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults());
-    rewriter.eraseOp(op);
-    return success();
-  }
-
-private:
-  bool isValueFromInsideRegion(Value value,
-                               ExecuteRegionOp executeRegionOp) const {
-    // Check if the value is defined within the execute_region
-    if (Operation *defOp = value.getDefiningOp())
-      return &executeRegionOp.getRegion() == defOp->getParentRegion();
-
-    // If it's a block argument, check if it's from within the region
-    if (BlockArgument blockArg = dyn_cast<BlockArgument>(value))
-      return &executeRegionOp.getRegion() == blockArg.getParentRegion();
-
-    return false; // Value is from outside the region
-  }
-};
-
 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner,
-              ExecuteRegionForwardingEliminator>(context);
+  // NOTE(review): ExecuteRegionForwardingEliminator was removed; folding of
+  // forwarded external values now relies on RemoveUsesOfIdenticalValues and
+  // RemoveDeadValues, registered via ForOp::getCanonicalizationPatterns.
+  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
 }
 
 void ExecuteRegionOp::getSuccessorRegions(
@@ -1234,91 +1142,199 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
   }
 };
 
-/// Rewriting pattern that folds away cycles in the yield of a scf.for op.
-///
-/// ```
-/// %res:2 = scf.for ... iter_args(%arg0 = %init, %arg1 = %init) {
-///   ...
-///   use %arg0, %arg1
-///   scf.yield %arg1, %arg0
-/// }
-/// return %res#0, %res#1
-/// ```
-///
-/// folds into:
-///
-/// ```
-/// scf.for ... iter_args() {
-///   ...
-///   use %init, %init
-///   scf.yield
-/// }
-/// return %init, %init
-/// ```
-struct ForOpYieldCyclesFolder : public OpRewritePattern<ForOp> {
-  using Base::Base;
+/// Return true if `a` is known to be defined before `b`: `a` is defined in a
+/// block that contains `b` (possibly nested in an op of that block) and at an
+/// earlier position. Conservatively returns false when this cannot be shown.
+static bool isDefinedBefore(Value a, Value b) {
+  Block *aBlock = a.getParentBlock();
+  Block *bBlock = b.getParentBlock();
 
-  LogicalResult matchAndRewrite(ForOp op,
+  if (!aBlock || !bBlock)
+    return false;
+  // Op of `aBlock` that defines or (transitively) contains `b`; null if `b` is
+  // an argument of `aBlock`.
+  Operation *bAnchor = b.getDefiningOp();
+  if (bBlock != aBlock) {
+    Operation *bParent = bBlock->getParentOp();
+    bAnchor = bParent ? aBlock->findAncestorOpInBlock(*bParent) : nullptr;
+    if (!bAnchor)
+      return false;
+  }
+  if (isa<BlockArgument>(a))
+    return true;
+  return bAnchor && a.getDefiningOp()->isBeforeInBlock(bAnchor);
+}
+
+/// Replace the uses of each successor input whose only possible value is a
+/// different value defined before it with that value. This makes successor
+/// inputs dead and enables further canonicalization by RemoveDeadValues.
+struct RemoveUsesOfIdenticalValues
+    : public OpInterfaceRewritePattern<RegionBranchOpInterface> {
+  using OpInterfaceRewritePattern<
+      RegionBranchOpInterface>::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(RegionBranchOpInterface op,
                                 PatternRewriter &rewriter) const override {
-    ValueRange yieldedValues = op.getYieldedValues();
-    ValueRange initArgs = op.getInitArgs();
-    ValueRange results = op.getResults();
-    ValueRange regionIterArgs = op.getRegionIterArgs();
-    Block *body = op.getBody();
+    // TODO: ForallOp data flow is modeled incompletely.
+    if (isa<ForallOp>(op))
+      return failure();
 
-    unsigned numYieldedValues = op.getNumRegionIterArgs();
+    // Gather candidate successor inputs: op results and entry block args.
+    // Non-successor-inputs have no possible values and are skipped below.
+    SmallVector<Value> values;
+    llvm::append_range(values, op->getResults());
+    for (Region &r : op->getRegions())
+      llvm::append_range(values, r.getArguments());
 
     bool changed = false;
-    SmallVector<unsigned> cycle;
-    llvm::SmallBitVector visited(numYieldedValues, false);
+    for (Value value : values) {
+      if (value.use_empty())
+        continue;
+      DenseSet<Value> possibleValues =
+          op.computePossibleValuesOfSuccessorInput(value);
+      if (possibleValues.size() == 1 && *possibleValues.begin() != value &&
+          isDefinedBefore(*possibleValues.begin(), value)) {
+        // `value` always equals this earlier-defined value; forward it.
+        rewriter.replaceAllUsesWith(value, *possibleValues.begin());
+        changed = true;
+      }
+    }
+    return success(changed);
+  }
+};
 
-    // Go through all possible start points for the cycle.
-    for (auto start : llvm::seq(numYieldedValues)) {
-      if (visited[start])
+/// Pattern to remove dead values from region branch ops.
+struct RemoveDeadValues
+    : public OpInterfaceRewritePattern<RegionBranchOpInterface> {
+  using OpInterfaceRewritePattern<
+      RegionBranchOpInterface>::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(RegionBranchOpInterface op,
+                                PatternRewriter &rewriter) const override {
+    // TODO: ForallOp data flow is modeled incompletely.
+    if (isa<ForallOp>(op))
+      return failure();
+
+    // Compute tied values: values that must come as a set. If you remove one,
+    // you must remove all.
+    RegionBranchSuccessorMapping operandToInputs;
+    op.getSuccessorOperandInputMapping(operandToInputs);
+    llvm::EquivalenceClasses<Value> tiedSuccessorInputs;
+    for (const auto &[operand, inputs] : operandToInputs) {
+      assert(!inputs.empty() && "expected non-empty inputs");
+      Value firstInput = inputs.front();
+      tiedSuccessorInputs.insert(firstInput);
+      for (Value nextInput : llvm::drop_begin(inputs))
+        tiedSuccessorInputs.unionSets(firstInput, nextInput);
+    }
+
+    // Determine which values to remove and group them by block and operation.
+    SmallVector<Value> valuesToRemove;
+    DenseMap<Block *, BitVector> blockArgsToRemove;
+    DenseMap<Operation *, BitVector> resultsToRemove;
+    for (auto it = tiedSuccessorInputs.begin(), e = tiedSuccessorInputs.end();
+         it != e; ++it) {
+      if (!(*it)->isLeader())
         continue;
 
-      cycle.clear();
-      unsigned current = start;
-      bool validCycle = true;
-      Value initValue = initArgs[start];
-      // Go through yield -> block arg -> yield cycles and check if all values
-      // are always equal to the init.
-      while (!visited[current]) {
-        cycle.push_back(current);
-        visited[current] = true;
-
-        // Find whether this yield is from a region iter arg.
-        auto yieldedValue = yieldedValues[current];
-        if (auto arg = dyn_cast<BlockArgument>(yieldedValue);
-            !arg || arg.getOwner() != body) {
-          validCycle = false;
+      // Value can be removed if it is dead and all other tied values are also
+      // dead.
+      bool allDead = true;
+      for (auto memberIt = tiedSuccessorInputs.member_begin(**it);
+           memberIt != tiedSuccessorInputs.member_end(); ++memberIt) {
+        if (!memberIt->use_empty()) {
+          allDead = false;
           break;
         }
+      }
+      if (!allDead)
+        continue;
 
-        // Next yield position.
-        current = cast<BlockArgument>(yieldedValue).getArgNumber() -
-                  op.getNumInductionVars();
-
-        // Check if next position has the same init value.
-        if (initArgs[current] != initValue) {
-          validCycle = false;
-          break;
+      // Group values by block and operation.
+      for (auto memberIt = tiedSuccessorInputs.member_begin(**it);
+           memberIt != tiedSuccessorInputs.member_end(); ++memberIt) {
+        if (auto arg = dyn_cast<BlockArgument>(*memberIt)) {
+          BitVector &vector =
+              blockArgsToRemove
+                  .try_emplace(arg.getOwner(),
+                               arg.getOwner()->getNumArguments(), false)
+                  .first->second;
+          vector.set(arg.getArgNumber());
+        } else {
+          OpResult result = cast<OpResult>(*memberIt);
+          BitVector &vector =
+              resultsToRemove
+                  .try_emplace(result.getDefiningOp(),
+                               result.getDefiningOp()->getNumResults(), false)
+                  .first->second;
+          vector.set(result.getResultNumber());
         }
+        valuesToRemove.push_back(*memberIt);
       }
+    }
 
-      // If we found a valid cycle (yielding own iter arg forms cycle of length
-      // 1), all values in it are always equal to initValue.
-      if (validCycle) {
-        changed = true;
-        for (unsigned idx : cycle) {
-          // This will leave region args and results dead so other
-          // canonicalization patterns can clean them up.
-          rewriter.replaceAllUsesWith(regionIterArgs[idx], initValue);
-          rewriter.replaceAllUsesWith(results[idx], initValue);
+    if (valuesToRemove.empty())
+      return rewriter.notifyMatchFailure(op, "no values to remove");
+
+    // Find operands that must be removed together with the values.
+    RegionBranchInverseSuccessorMapping inputsToOperands;
+    op.getSuccessorInputOperandMapping(inputsToOperands);
+    DenseMap<Operation *, BitVector> operandsToRemove;
+    for (Value value : valuesToRemove) {
+      for (OpOperand *operand : inputsToOperands[value]) {
+        BitVector &vector =
+            operandsToRemove
+                .try_emplace(operand->getOwner(),
+                             operand->getOwner()->getNumOperands(), false)
+                .first->second;
+        vector.set(operand->getOperandNumber());
+      }
+    }
+
+    // Erase operands and block arguments. The lambdas capture plain locals:
+    // capturing structured bindings requires C++20.
+    for (auto &it : operandsToRemove) {
+      Operation *user = it.first;
+      rewriter.modifyOpInPlace(user, [&]() { user->eraseOperands(it.second); });
+    }
+
+    for (auto &it : blockArgsToRemove)
+      rewriter.modifyOpInPlace(it.first->getParentOp(),
+                               [&]() { it.first->eraseArguments(it.second); });
+
+    // Erase op results.
+    // TODO: Can we move this to RewriterBase, so we have a uniform API,
+    // similar to eraseArguments?
+    for (auto &[old, resultsToErase] : resultsToRemove) {
+      rewriter.setInsertionPoint(old);
+      SmallVector<Type> newResultTypes;
+      for (OpResult result : old->getResults())
+        if (!resultsToErase[result.getResultNumber()])
+          newResultTypes.push_back(result.getType());
+      OperationState state(old->getLoc(), old->getName().getStringRef(),
+                           old->getOperands(), newResultTypes, old->getAttrs());
+      for (unsigned i = 0, e = old->getNumRegions(); i < e; ++i)
+        state.addRegion();
+      Operation *newOp = rewriter.create(state);
+      for (const auto &[index, region] : llvm::enumerate(old->getRegions())) {
+        // Move all blocks of `region` into `newRegion`.
+        Region &newRegion = newOp->getRegion(index);
+        rewriter.inlineRegionBefore(region, newRegion, newRegion.begin());
+      }
+
+      SmallVector<Value> newResults;
+      unsigned nextLiveResult = 0;
+      for (auto [index, result] : llvm::enumerate(old->getResults())) {
+        if (!resultsToErase[index]) {
+          newResults.push_back(newOp->getResult(nextLiveResult++));
+        } else {
+          newResults.push_back(Value());
         }
       }
+      rewriter.replaceOp(old, newResults);
     }
-    return success(changed);
+
+    return success();
   }
 };
 
@@ -1326,8 +1342,11 @@ struct ForOpYieldCyclesFolder : public OpRewritePattern<ForOp> {
 
 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder,
-              ForOpYieldCyclesFolder>(context);
+  // TODO: ForOpIterArgsFolder (disabled below) also removes duplicate bbargs;
+  // can the RegionBranchOpInterface patterns below take over that fold?
+  results.add</*ForOpIterArgsFolder, */ SimplifyTrivialLoops,
+              ForOpTensorCastFolder, RemoveUsesOfIdenticalValues,
+              RemoveDeadValues>(context);
 }
 
 std::optional<APInt> ForOp::getConstantStep() {
@@ -2495,61 +2514,6 @@ void IfOp::getRegionInvocationBounds(
 }
 
 namespace {
-// Pattern to remove unused IfOp results.
-struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
-  using OpRewritePattern<IfOp>::OpRewritePattern;
-
-  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
-                    PatternRewriter &rewriter) const {
-    // Move all operations to the destination block.
-    rewriter.mergeBlocks(source, dest);
-    // Replace the yield op by one that returns only the used values.
-    auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
-    SmallVector<Value, 4> usedOperands;
-    llvm::transform(usedResults, std::back_inserter(usedOperands),
-                    [&](OpResult result) {
-                      return yieldOp.getOperand(result.getResultNumber());
-                    });
-    rewriter.modifyOpInPlace(yieldOp,
-                             [&]() { yieldOp->setOperands(usedOperands); });
-  }
-
-  LogicalResult matchAndRewrite(IfOp op,
-                                PatternRewriter &rewriter) const override {
-    // Compute the list of used results.
-    SmallVector<OpResult, 4> usedResults;
-    llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
-                  [](OpResult result) { return !result.use_empty(); });
-
-    // Replace the operation if only a subset of its results have uses.
-    if (usedResults.size() == op.getNumResults())
-      return failure();
-
-    // Compute the result types of the replacement operation.
-    SmallVector<Type, 4> newTypes;
-    llvm::transform(usedResults, std::back_inserter(newTypes),
-                    [](OpResult result) { return result.getType(); });
-
-    // Create a replacement operation with empty then and else regions.
-    auto newOp =
-        IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition());
-    rewriter.createBlock(&newOp.getThenRegion());
-    rewriter.createBlock(&newOp.getElseRegion());
-
-    // Move the bodies and replace the terminators (note there is a then and
-    // an else region since the operation returns results).
-    transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
-    transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
-
-    // Replace the operation by the new one.
-    SmallVector<Value, 4> repResults(op.getNumResults());
-    for (const auto &en : llvm::enumerate(usedResults))
-      repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
-    rewriter.replaceOp(op, repResults);
-    return success();
-  }
-};
-
 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
   using OpRewritePattern<IfOp>::OpRewritePattern;
 
@@ -3120,8 +3084,10 @@ void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
+  // NOTE(review): Unused-result removal now relies on RemoveDeadValues, which
+  // is registered via ForOp::getCanonicalizationPatterns rather than here.
   results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
               ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
-              RemoveStaticCondition, RemoveUnusedResults,
-              ReplaceIfYieldWithConditionOrValue>(context);
+              RemoveStaticCondition, ReplaceIfYieldWithConditionOrValue>(
+      context);
 }
 
 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
@@ -3959,390 +3923,6 @@ struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
   }
 };
 
-/// Remove loop invariant arguments from `before` block of scf.while.
-/// A before block argument is considered loop invariant if :-
-///   1. i-th yield operand is equal to the i-th while operand.
-///   2. i-th yield operand is k-th after block argument which is (k+1)-th
-///      condition operand AND this (k+1)-th condition operand is equal to i-th
-///      iter argument/while operand.
-/// For the arguments which are removed, their uses inside scf.while
-/// are replaced with their corresponding initial value.
-///
-/// Eg:
-///    INPUT :-
-///    %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
-///                                     ..., %argN_before = %N)
-///           {
-///                ...
-///                scf.condition(%cond) %arg1_before, %arg0_before,
-///                                     %arg2_before, %arg0_before, ...
-///           } do {
-///             ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
-///                  ..., %argK_after):
-///                ...
-///                scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
-///           }
-///
-///    OUTPUT :-
-///    %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
-///                                     %N)
-///           {
-///                ...
-///                scf.condition(%cond) %b, %a, %arg2_before, %a, ...
-///           } do {
-///             ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
-///                  ..., %argK_after):
-///                ...
-///                scf.yield %arg1_after, ..., %argN
-///           }
-///
-///    EXPLANATION:
-///      We iterate over each yield operand.
-///        1. 0-th yield operand %arg0_after_2 is 4-th condition operand
-///           %arg0_before, which in turn is the 0-th iter argument. So we
-///           remove 0-th before block argument and yield operand, and replace
-///           all uses of the 0-th before block argument with its initial value
-///           %a.
-///        2. 1-th yield operand %b is equal to the 1-th iter arg's initial
-///           value. So we remove this operand and the corresponding before
-///           block argument and replace all uses of 1-th before block argument
-///           with %b.
-struct RemoveLoopInvariantArgsFromBeforeBlock
-    : public OpRewritePattern<WhileOp> {
-  using OpRewritePattern<WhileOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(WhileOp op,
-                                PatternRewriter &rewriter) const override {
-    Block &afterBlock = *op.getAfterBody();
-    Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
-    ConditionOp condOp = op.getConditionOp();
-    OperandRange condOpArgs = condOp.getArgs();
-    Operation *yieldOp = afterBlock.getTerminator();
-    ValueRange yieldOpArgs = yieldOp->getOperands();
-
-    bool canSimplify = false;
-    for (const auto &it :
-         llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
-      auto index = static_cast<unsigned>(it.index());
-      auto [initVal, yieldOpArg] = it.value();
-      // If i-th yield operand is equal to the i-th operand of the scf.while,
-      // the i-th before block argument is a loop invariant.
-      if (yieldOpArg == initVal) {
-        canSimplify = true;
-        break;
-      }
-      // If the i-th yield operand is k-th after block argument, then we check
-      // if the (k+1)-th condition op operand is equal to either the i-th before
-      // block argument or the initial value of i-th before block argument. If
-      // the comparison results `true`, i-th before block argument is a loop
-      // invariant.
-      auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
-      if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
-        Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
-        if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
-          canSimplify = true;
-          break;
-        }
-      }
-    }
-
-    if (!canSimplify)
-      return failure();
-
-    SmallVector<Value> newInitArgs, newYieldOpArgs;
-    DenseMap<unsigned, Value> beforeBlockInitValMap;
-    SmallVector<Location> newBeforeBlockArgLocs;
-    for (const auto &it :
-         llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
-      auto index = static_cast<unsigned>(it.index());
-      auto [initVal, yieldOpArg] = it.value();
-
-      // If i-th yield operand is equal to the i-th operand of the scf.while,
-      // the i-th before block argument is a loop invariant.
-      if (yieldOpArg == initVal) {
-        beforeBlockInitValMap.insert({index, initVal});
-        continue;
-      } else {
-        // If the i-th yield operand is k-th after block argument, then we check
-        // if the (k+1)-th condition op operand is equal to either the i-th
-        // before block argument or the initial value of i-th before block
-        // argument. If the comparison results `true`, i-th before block
-        // argument is a loop invariant.
-        auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
-        if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
-          Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
-          if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
-            beforeBlockInitValMap.insert({index, initVal});
-            continue;
-          }
-        }
-      }
-      newInitArgs.emplace_back(initVal);
-      newYieldOpArgs.emplace_back(yieldOpArg);
-      newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
-    }
-
-    {
-      OpBuilder::InsertionGuard g(rewriter);
-      rewriter.setInsertionPoint(yieldOp);
-      rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
-    }
-
-    auto newWhile = WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(),
-                                    newInitArgs);
-
-    Block &newBeforeBlock = *rewriter.createBlock(
-        &newWhile.getBefore(), /*insertPt*/ {},
-        ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
-
-    Block &beforeBlock = *op.getBeforeBody();
-    SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
-    // For each i-th before block argument we find it's replacement value as :-
-    //   1. If i-th before block argument is a loop invariant, we fetch it's
-    //      initial value from `beforeBlockInitValMap` by querying for key `i`.
-    //   2. Else we fetch j-th new before block argument as the replacement
-    //      value of i-th before block argument.
-    for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) 
{
-      // If the index 'i' argument was a loop invariant we fetch it's initial
-      // value from `beforeBlockInitValMap`.
-      if (beforeBlockInitValMap.count(i) != 0)
-        newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
-      else
-        newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
-    }
-
-    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
-    rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
-                                newWhile.getAfter().begin());
-
-    rewriter.replaceOp(op, newWhile.getResults());
-    return success();
-  }
-};
-
-/// Remove loop invariant value from result (condition op) of scf.while.
-/// A value is considered loop invariant if the final value yielded by
-/// scf.condition is defined outside of the `before` block. We remove the
-/// corresponding argument in `after` block and replace the use with the value.
-/// We also replace the use of the corresponding result of scf.while with the
-/// value.
-///
-/// Eg:
-///    INPUT :-
-///    %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
-///                                             %argN_before = %N) {
-///                ...
-///                scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
-///           } do {
-///             ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
-///                ...
-///                some_func(%arg1_after)
-///                ...
-///                scf.yield %arg0_after, %arg2_after, ..., %argN_after
-///           }
-///
-///    OUTPUT :-
-///    %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
-///                ...
-///                scf.condition(%cond) %arg0, %arg1, ..., %argM
-///           } do {
-///             ^bb0(%arg0, %arg3, ..., %argM):
-///                ...
-///                some_func(%a)
-///                ...
-///                scf.yield %arg0, %b, ..., %argN
-///           }
-///
-///     EXPLANATION:
-///       1. The 1-th and 2-th operand of scf.condition are defined outside the
-///          before block of scf.while, so they get removed.
-///       2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
-///          replaced by %b.
-///       3. The corresponding after block argument %arg1_after's uses are
-///          replaced by %a and %arg2_after's uses are replaced by %b.
-struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
-  using OpRewritePattern<WhileOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(WhileOp op,
-                                PatternRewriter &rewriter) const override {
-    Block &beforeBlock = *op.getBeforeBody();
-    ConditionOp condOp = op.getConditionOp();
-    OperandRange condOpArgs = condOp.getArgs();
-
-    bool canSimplify = false;
-    for (Value condOpArg : condOpArgs) {
-      // Those values not defined within `before` block will be considered as
-      // loop invariant values. We map the corresponding `index` with their
-      // value.
-      if (condOpArg.getParentBlock() != &beforeBlock) {
-        canSimplify = true;
-        break;
-      }
-    }
-
-    if (!canSimplify)
-      return failure();
-
-    Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
-
-    SmallVector<Value> newCondOpArgs;
-    SmallVector<Type> newAfterBlockType;
-    DenseMap<unsigned, Value> condOpInitValMap;
-    SmallVector<Location> newAfterBlockArgLocs;
-    for (const auto &it : llvm::enumerate(condOpArgs)) {
-      auto index = static_cast<unsigned>(it.index());
-      Value condOpArg = it.value();
-      // Those values not defined within `before` block will be considered as
-      // loop invariant values. We map the corresponding `index` with their
-      // value.
-      if (condOpArg.getParentBlock() != &beforeBlock) {
-        condOpInitValMap.insert({index, condOpArg});
-      } else {
-        newCondOpArgs.emplace_back(condOpArg);
-        newAfterBlockType.emplace_back(condOpArg.getType());
-        newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
-      }
-    }
-
-    {
-      OpBuilder::InsertionGuard g(rewriter);
-      rewriter.setInsertionPoint(condOp);
-      rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
-                                               newCondOpArgs);
-    }
-
-    auto newWhile = WhileOp::create(rewriter, op.getLoc(), newAfterBlockType,
-                                    op.getOperands());
-
-    Block &newAfterBlock =
-        *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
-                              newAfterBlockType, newAfterBlockArgLocs);
-
-    Block &afterBlock = *op.getAfterBody();
-    // Since a new scf.condition op was created, we need to fetch the new
-    // `after` block arguments which will be used while replacing operations of
-    // previous scf.while's `after` blocks. We'd also be fetching new result
-    // values too.
-    SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
-    SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
-    for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
-      Value afterBlockArg, result;
-      // If index 'i' argument was loop invariant we fetch it's value from the
-      // `condOpInitMap` map.
-      if (condOpInitValMap.count(i) != 0) {
-        afterBlockArg = condOpInitValMap[i];
-        result = afterBlockArg;
-      } else {
-        afterBlockArg = newAfterBlock.getArgument(j);
-        result = newWhile.getResult(j);
-        j++;
-      }
-      newAfterBlockArgs[i] = afterBlockArg;
-      newWhileResults[i] = result;
-    }
-
-    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
-    rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
-                                newWhile.getBefore().begin());
-
-    rewriter.replaceOp(op, newWhileResults);
-    return success();
-  }
-};
-
-/// Remove WhileOp results that are also unused in 'after' block.
-///
-///  %0:2 = scf.while () : () -> (i32, i64) {
-///    %condition = "test.condition"() : () -> i1
-///    %v1 = "test.get_some_value"() : () -> i32
-///    %v2 = "test.get_some_value"() : () -> i64
-///    scf.condition(%condition) %v1, %v2 : i32, i64
-///  } do {
-///  ^bb0(%arg0: i32, %arg1: i64):
-///    "test.use"(%arg0) : (i32) -> ()
-///    scf.yield
-///  }
-///  return %0#0 : i32
-///
-/// becomes
-///  %0 = scf.while () : () -> (i32) {
-///    %condition = "test.condition"() : () -> i1
-///    %v1 = "test.get_some_value"() : () -> i32
-///    %v2 = "test.get_some_value"() : () -> i64
-///    scf.condition(%condition) %v1 : i32
-///  } do {
-///  ^bb0(%arg0: i32):
-///    "test.use"(%arg0) : (i32) -> ()
-///    scf.yield
-///  }
-///  return %0 : i32
-struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
-  using OpRewritePattern<WhileOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(WhileOp op,
-                                PatternRewriter &rewriter) const override {
-    auto term = op.getConditionOp();
-    auto afterArgs = op.getAfterArguments();
-    auto termArgs = term.getArgs();
-
-    // Collect results mapping, new terminator args and new result types.
-    SmallVector<unsigned> newResultsIndices;
-    SmallVector<Type> newResultTypes;
-    SmallVector<Value> newTermArgs;
-    SmallVector<Location> newArgLocs;
-    bool needUpdate = false;
-    for (const auto &it :
-         llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
-      auto i = static_cast<unsigned>(it.index());
-      Value result = std::get<0>(it.value());
-      Value afterArg = std::get<1>(it.value());
-      Value termArg = std::get<2>(it.value());
-      if (result.use_empty() && afterArg.use_empty()) {
-        needUpdate = true;
-      } else {
-        newResultsIndices.emplace_back(i);
-        newTermArgs.emplace_back(termArg);
-        newResultTypes.emplace_back(result.getType());
-        newArgLocs.emplace_back(result.getLoc());
-      }
-    }
-
-    if (!needUpdate)
-      return failure();
-
-    {
-      OpBuilder::InsertionGuard g(rewriter);
-      rewriter.setInsertionPoint(term);
-      rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
-                                               newTermArgs);
-    }
-
-    auto newWhile =
-        WhileOp::create(rewriter, op.getLoc(), newResultTypes, op.getInits());
-
-    Block &newAfterBlock = *rewriter.createBlock(
-        &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
-
-    // Build new results list and new after block args (unused entries will be
-    // null).
-    SmallVector<Value> newResults(op.getNumResults());
-    SmallVector<Value> newAfterBlockArgs(op.getNumResults());
-    for (const auto &it : llvm::enumerate(newResultsIndices)) {
-      newResults[it.value()] = newWhile.getResult(it.index());
-      newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
-    }
-
-    rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
-                                newWhile.getBefore().begin());
-
-    Block &afterBlock = *op.getAfterBody();
-    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
-
-    rewriter.replaceOp(op, newResults);
-    return success();
-  }
-};
-
 /// Replace operations equivalent to the condition in the do block with true,
 /// since otherwise the block would not be evaluated.
 ///
@@ -4407,65 +3987,6 @@ struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
   }
 };
 
-/// Remove unused init/yield args.
-struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
-  using OpRewritePattern<WhileOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(WhileOp op,
-                                PatternRewriter &rewriter) const override {
-
-    if (!llvm::any_of(op.getBeforeArguments(),
-                      [](Value arg) { return arg.use_empty(); }))
-      return rewriter.notifyMatchFailure(op, "No args to remove");
-
-    YieldOp yield = op.getYieldOp();
-
-    // Collect results mapping, new terminator args and new result types.
-    SmallVector<Value> newYields;
-    SmallVector<Value> newInits;
-    llvm::BitVector argsToErase;
-
-    size_t argsCount = op.getBeforeArguments().size();
-    newYields.reserve(argsCount);
-    newInits.reserve(argsCount);
-    argsToErase.reserve(argsCount);
-    for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
-             op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
-      if (beforeArg.use_empty()) {
-        argsToErase.push_back(true);
-      } else {
-        argsToErase.push_back(false);
-        newYields.emplace_back(yieldValue);
-        newInits.emplace_back(initValue);
-      }
-    }
-
-    Block &beforeBlock = *op.getBeforeBody();
-    Block &afterBlock = *op.getAfterBody();
-
-    beforeBlock.eraseArguments(argsToErase);
-
-    Location loc = op.getLoc();
-    auto newWhileOp =
-        WhileOp::create(rewriter, loc, op.getResultTypes(), newInits,
-                        /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
-    Block &newBeforeBlock = *newWhileOp.getBeforeBody();
-    Block &newAfterBlock = *newWhileOp.getAfterBody();
-
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(yield);
-    rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
-
-    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
-                         newBeforeBlock.getArguments());
-    rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
-                         newAfterBlock.getArguments());
-
-    rewriter.replaceOp(op, newWhileOp.getResults());
-    return success();
-  }
-};
-
 /// Remove duplicated ConditionOp args.
 struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -4618,11 +4139,8 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
 
 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
-  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
-              RemoveLoopInvariantValueYielded, WhileConditionTruth,
-              WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
-              WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
-      context);
+  results.add<WhileConditionTruth, WhileCmpCond, WhileRemoveDuplicatedResults,
+              WhileOpAlignBeforeArgs, WhileMoveIfDown>(context);
 }
 
 
//===----------------------------------------------------------------------===//
@@ -4797,59 +4315,9 @@ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
   }
 };
 
-/// Canonicalization patterns that folds away dead results of
-/// "scf.index_switch" ops.
-struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
-  using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(IndexSwitchOp op,
-                                PatternRewriter &rewriter) const override {
-    // Find dead results.
-    BitVector deadResults(op.getNumResults(), false);
-    SmallVector<Type> newResultTypes;
-    for (auto [idx, result] : llvm::enumerate(op.getResults())) {
-      if (!result.use_empty()) {
-        newResultTypes.push_back(result.getType());
-      } else {
-        deadResults[idx] = true;
-      }
-    }
-    if (!deadResults.any())
-      return rewriter.notifyMatchFailure(op, "no dead results to fold");
-
-    // Create new op without dead results and inline case regions.
-    auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes,
-                                       op.getArg(), op.getCases(),
-                                       op.getCaseRegions().size());
-    auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) {
-      rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
-      // Remove respective operands from yield op.
-      Operation *terminator = newRegion.front().getTerminator();
-      assert(isa<YieldOp>(terminator) && "expected yield op");
-      rewriter.modifyOpInPlace(
-          terminator, [&]() { terminator->eraseOperands(deadResults); });
-    };
-    for (auto [oldRegion, newRegion] :
-         llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions()))
-      inlineCaseRegion(oldRegion, newRegion);
-    inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion());
-
-    // Replace op with new op.
-    SmallVector<Value> newResults(op.getNumResults(), Value());
-    unsigned nextNewResult = 0;
-    for (unsigned idx = 0; idx < op.getNumResults(); ++idx) {
-      if (deadResults[idx])
-        continue;
-      newResults[idx] = newOp.getResult(nextNewResult++);
-    }
-    rewriter.replaceOp(op, newResults);
-    return success();
-  }
-};
-
 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context);
+  results.add<FoldConstantCase>(context);
 }
 
 
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index d393ddb8d8336..ed94205d32f19 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -521,6 +521,45 @@ void RegionBranchOpInterface::getSuccessorOperandInputMapping(
   }
 }
 
+void RegionBranchOpInterface::getSuccessorInputOperandMapping(
+    RegionBranchInverseSuccessorMapping &mapping) {
+  // Invert the operand -> successor inputs mapping.
+  RegionBranchSuccessorMapping operandToInputs;
+  getSuccessorOperandInputMapping(operandToInputs);
+  for (const auto &[operand, inputs] : operandToInputs)
+    for (Value input : inputs)
+      mapping[input].push_back(operand);
+}
+
+DenseSet<Value>
+RegionBranchOpInterface::computePossibleValuesOfSuccessorInput(Value value) {
+  RegionBranchInverseSuccessorMapping inputToOperands;
+  getSuccessorInputOperandMapping(inputToOperands);
+
+  DenseSet<Value> possibleValues;
+  DenseSet<Value> visited;
+  SmallVector<Value> worklist;
+
+  // Starting with the given value, trace back all predecessor values (i.e.,
+  // preceding successor operands). Values with predecessors are traversed
+  // further; values without predecessors are added to the result set.
+  visited.insert(value); // Seed is visited: avoids re-scanning it on cycles.
+  worklist.push_back(value);
+  while (!worklist.empty()) {
+    Value next = worklist.pop_back_val();
+    auto it = inputToOperands.find(next);
+    if (it == inputToOperands.end()) {
+      possibleValues.insert(next);
+      continue;
+    }
+    for (OpOperand *operand : it->second)
+      if (visited.insert(operand->get()).second)
+        worklist.push_back(operand->get());
+  }
+
+  return possibleValues;
+}
+
 SmallVector<RegionBranchPoint>
 RegionBranchOpInterface::getAllRegionBranchPoints() {
   SmallVector<RegionBranchPoint> branchPoints;
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir 
b/mlir/test/Dialect/SCF/canonicalize.mlir
index 984ea10f7e540..0420d1c018d76 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1736,11 +1736,11 @@ module {
 
 // Test case with multiple scf.yield ops with at least one different operand, then no change.
 
-// CHECK:           %[[VAL_3:.*]]:2 = scf.execute_region -> (memref<1x60xui8>, 
memref<1x120xui8>) no_inline {
+// CHECK:           %[[VAL_3:.*]] = scf.execute_region -> memref<1x120xui8> 
no_inline {
 // CHECK:           ^bb1:
-// CHECK:             scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, 
memref<1x120xui8>
+// CHECK:             scf.yield %{{.*}} : memref<1x120xui8>
 // CHECK:           ^bb2:
-// CHECK:             scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, 
memref<1x120xui8>
+// CHECK:             scf.yield  %{{.*}} : memref<1x120xui8>
 // CHECK:           }
 
 module {
@@ -2214,16 +2214,14 @@ func.func @iter_args_cycles_non_cycle_start(%lb : index, %ub : index, %step : in
 //  CHECK-SAME:     %[[arg0:.*]]: index
 //   CHECK-DAG:   %[[c10:.*]] = arith.constant 10
 //   CHECK-DAG:   %[[c11:.*]] = arith.constant 11
-//       CHECK:   %[[switch:.*]] = scf.index_switch %[[arg0]] -> index
+//       CHECK:   scf.index_switch %[[arg0]]
 //       CHECK:   case 1 {
 //       CHECK:     memref.store %[[c10]]
-//       CHECK:     scf.yield %[[arg0]] : index
 //       CHECK:   } 
 //       CHECK:   default {
 //       CHECK:     memref.store %[[c11]]
-//       CHECK:     scf.yield %[[arg0]] : index
 //       CHECK:   }
-//       CHECK:   return %[[switch]]
+//       CHECK:   return %[[arg0]]
 func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> 
index {
   %non_live, %live = scf.index_switch %arg0 -> i32, index
   case 1 {

>From 447cadee988ede17eda4df12bc316c8b7e688807 Mon Sep 17 00:00:00 2001
From: Matthias Springer <[email protected]>
Date: Wed, 31 Dec 2025 16:00:05 +0000
Subject: [PATCH 2/2] fix some tests

---
 .../mlir/Dialect/SparseTensor/IR/SparseTensorOps.td        | 7 ++++---
 mlir/test/Dialect/Vector/vector-warp-distribute.mlir       | 6 +++---
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index a61d90a0c39b1..f41b3694d9c79 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1304,9 +1304,10 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
   let hasVerifier = 1;
 }
 
-def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
-    ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
-                 "ForeachOp", "IterateOp", "CoIterateOp"]>]> {
+def SparseTensor_YieldOp : SparseTensor_Op<"yield",
+    [Pure, Terminator, ReturnLike,
+     ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
+                  "ForeachOp", "IterateOp", "CoIterateOp"]>]> {
   let summary = "Yield from sparse_tensor set-like operations";
   let description = [{
       Yields a value from within a `binary`, `unary`, `reduce`,
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 135db02d543ef..18fb6852f6875 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1330,11 +1330,11 @@ func.func @vector_insert_1d_broadcast(%laneid: index, %pos: index) -> (vector<96
 // -----
 
 // CHECK-PROP-LABEL: func @vector_insert_0d(
-//       CHECK-PROP:   %[[W:.*]]:2 = gpu.warp_execute_on_lane_0{{.*}} -> 
(vector<f32>, f32)
+//       CHECK-PROP:   %[[W:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (f32)
 //       CHECK-PROP:     %[[VEC:.*]] = "some_def"
 //       CHECK-PROP:     %[[VAL:.*]] = "another_def"
-//       CHECK-PROP:     gpu.yield %[[VEC]], %[[VAL]]
-//       CHECK-PROP:   vector.broadcast %[[W]]#1 : f32 to vector<f32>
+//       CHECK-PROP:     gpu.yield %[[VAL]]
+//       CHECK-PROP:   vector.broadcast %[[W]] : f32 to vector<f32>
 func.func @vector_insert_0d(%laneid: index) -> (vector<f32>) {
   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<f32>) {
     %0 = "some_def"() : () -> (vector<f32>)

_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to