https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/173560
>From da738b44231c3e2b2ede42380057774e12f0d9b3 Mon Sep 17 00:00:00 2001 From: Matthias Springer <[email protected]> Date: Thu, 25 Dec 2025 12:55:26 +0000 Subject: [PATCH 1/2] [mlir][SCF] Fold unused `index_switch` results --- mlir/lib/Dialect/SCF/IR/SCF.cpp | 52 ++++++++++++++++++++++++- mlir/test/Dialect/SCF/canonicalize.mlir | 31 +++++++++++++++ 2 files changed, 82 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 4a6b8aa7b1125..46d09abd89d69 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -4797,9 +4797,59 @@ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> { } }; +/// Canonicalization pattern that folds away dead results of +/// "scf.index_switch" ops. +struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> { + using OpRewritePattern<IndexSwitchOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(IndexSwitchOp op, + PatternRewriter &rewriter) const override { + // Find dead results. + BitVector deadResults(op.getNumResults(), false); + SmallVector<Type> newResultTypes; + for (auto [idx, result] : llvm::enumerate(op.getResults())) { + if (!result.use_empty()) { + newResultTypes.push_back(result.getType()); + } else { + deadResults[idx] = true; + } + } + if (!deadResults.any()) + return rewriter.notifyMatchFailure(op, "no dead results to fold"); + + // Create new op without dead results and inline case regions. + auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes, + op.getArg(), op.getCases(), + op.getCaseRegions().size()); + auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) { + rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin()); + // Remove respective operands from yield op. 
+ Operation *terminator = newRegion.front().getTerminator(); + assert(isa<YieldOp>(terminator) && "expected yield op"); + rewriter.modifyOpInPlace( + terminator, [&]() { terminator->eraseOperands(deadResults); }); + }; + for (auto [oldRegion, newRegion] : + llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions())) + inlineCaseRegion(oldRegion, newRegion); + inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion()); + + // Replace op with new op. + SmallVector<Value> newResults(op.getNumResults(), Value()); + unsigned nextNewResult = 0; + for (unsigned idx = 0; idx < op.getNumResults(); ++idx) { + if (deadResults[idx]) + continue; + newResults[idx] = newOp.getResult(nextNewResult++); + } + rewriter.replaceOp(op, newResults); + return success(); + } +}; + void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<FoldConstantCase>(context); + results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 37851710ef010..984ea10f7e540 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -2207,3 +2207,34 @@ func.func @iter_args_cycles_non_cycle_start(%lb : index, %ub : index, %step : in } return %res#0, %res#1, %res#2 : i32, i32, i32 } + +// ----- + +// CHECK-LABEL: func @dead_index_switch_result( +// CHECK-SAME: %[[arg0:.*]]: index +// CHECK-DAG: %[[c10:.*]] = arith.constant 10 +// CHECK-DAG: %[[c11:.*]] = arith.constant 11 +// CHECK: %[[switch:.*]] = scf.index_switch %[[arg0]] -> index +// CHECK: case 1 { +// CHECK: memref.store %[[c10]] +// CHECK: scf.yield %[[arg0]] : index +// CHECK: } +// CHECK: default { +// CHECK: memref.store %[[c11]] +// CHECK: scf.yield %[[arg0]] : index +// CHECK: } +// CHECK: return %[[switch]] +func.func @dead_index_switch_result(%arg0 : 
index, %arg1 : memref<i32>) -> index { + %non_live, %live = scf.index_switch %arg0 -> i32, index + case 1 { + %c10 = arith.constant 10 : i32 + memref.store %c10, %arg1[] : memref<i32> + scf.yield %c10, %arg0 : i32, index + } + default { + %c11 = arith.constant 11 : i32 + memref.store %c11, %arg1[] : memref<i32> + scf.yield %c11, %arg0 : i32, index + } + return %live : index +} >From 9bef67465ba47702c8a24b59fedde3d1aee8f9ad Mon Sep 17 00:00:00 2001 From: Matthias Springer <[email protected]> Date: Sun, 28 Dec 2025 19:10:20 +0100 Subject: [PATCH 2/2] Apply suggestions from code review Co-authored-by: Mehdi Amini <[email protected]> Co-authored-by: lonely eagle <[email protected]> --- mlir/lib/Dialect/SCF/IR/SCF.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 46d09abd89d69..178e344b8963e 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -4814,7 +4814,7 @@ struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> { deadResults[idx] = true; } } - if (!deadResults.any()) + if (newResultTypes.size() == op.getNumResults()) return rewriter.notifyMatchFailure(op, "no dead results to fold"); // Create new op without dead results and inline case regions. @@ -4837,7 +4837,7 @@ struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> { // Replace op with new op. SmallVector<Value> newResults(op.getNumResults(), Value()); unsigned nextNewResult = 0; - for (unsigned idx = 0; idx < op.getNumResults(); ++idx) { + for (unsigned idx = 0, e = op.getNumResults(); idx < e; ++idx) { if (deadResults[idx]) continue; newResults[idx] = newOp.getResult(nextNewResult++); _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
