https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/173505
>From 8c30a791bf275f7b8e288fcdcada94773881c5a6 Mon Sep 17 00:00:00 2001 From: Matthias Springer <[email protected]> Date: Wed, 24 Dec 2025 13:26:00 +0000 Subject: [PATCH] tmp commit simple test working draft: do not erase IR, just replace uses --- mlir/include/mlir/Transforms/Passes.h | 1 + mlir/include/mlir/Transforms/Passes.td | 10 + mlir/lib/Transforms/RemoveDeadValues.cpp | 484 +++++++------------ mlir/test/Transforms/remove-dead-values.mlir | 155 ++++-- 4 files changed, 287 insertions(+), 363 deletions(-) diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 724da009e70f1..9983944d374c5 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -42,6 +42,7 @@ class GreedyRewriteConfig; #define GEN_PASS_DECL_MEM2REG #define GEN_PASS_DECL_PRINTIRPASS #define GEN_PASS_DECL_PRINTOPSTATS +#define GEN_PASS_DECL_REMOVEDEADVALUES #define GEN_PASS_DECL_SCCP #define GEN_PASS_DECL_SROA #define GEN_PASS_DECL_STRIPDEBUGINFO diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index 55addfdb693e4..fc2d60d198cd6 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -246,7 +246,17 @@ def RemoveDeadValues : Pass<"remove-dead-values"> { do = square_and_double_of_y(5) print(do) ``` + + Note: If `canonicalize` is set to "false", this pass does not remove any + block arguments / op results from ops that implement the + RegionBranchOpInterface. Instead, it just sets dead operands to + "ub.poison". }]; + + let options = [ + Option<"canonicalize", "canonicalize", "bool", /*default=*/"true", + "Canonicalize region branch ops">, + ]; let constructor = "mlir::createRemoveDeadValuesPass()"; let dependentDialects = ["ub::UBDialect"]; } diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index fc2c2acf8afd3..17128c5cdf898 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -94,8 +94,11 @@ struct ResultsToCleanup { struct OperandsToCleanup { Operation *op; BitVector nonLive; - Operation *callee = - nullptr; // Optional: For CallOpInterface ops, stores the callee function + // Optional: For CallOpInterface ops, stores the callee function. + Operation *callee = nullptr; + // Determines whether the operand should be replaced with a ub.poison result + // or erased entirely. + bool replaceWithPoison = false; }; struct BlockArgsToCleanup { @@ -199,27 +202,6 @@ static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range, } } -/// Drop the uses of the i-th result of `op` and then erase it iff toErase[i] -/// is 1. -static void dropUsesAndEraseResults(Operation *op, BitVector toErase) { - assert(op->getNumResults() == toErase.size() && - "expected the number of results in `op` and the size of `toErase` to " - "be the same"); - for (auto idx : toErase.set_bits()) - op->getResult(idx).dropAllUses(); - IRRewriter rewriter(op); - rewriter.eraseOpResults(op, toErase); -} - -/// Convert a list of `Operand`s to a list of `OpOperand`s. -static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) { - OpOperand *values = operands.getBase(); - SmallVector<OpOperand *> opOperands; - for (unsigned i = 0, e = operands.size(); i < e; i++) - opOperands.push_back(&values[i]); - return opOperands; -} - /// Process a simple operation `op` using the liveness analysis `la`. /// If the operation has no memory effects and none of its results are live: /// 1. Add the operation to a list for future removal, and @@ -379,30 +361,20 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module, /// /// Scenario 1: If the operation has no memory effects and none of its results /// are live: -/// (1') Enqueue all its uses for deletion. -/// (2') Enqueue the branch itself for deletion. +/// 1.1. Enqueue all its uses for deletion. +/// 1.2. Enqueue the branch itself for deletion. /// /// Scenario 2: Otherwise: -/// (1) Collect its unnecessary operands (operands forwarded to unnecessary -/// results or arguments). -/// (2) Process each of its regions. -/// (3) Collect the uses of its unnecessary results (results forwarded from -/// unnecessary operands -/// or terminator operands). -/// (4) Add these results to the deletion list. +/// 2.1. Collect block arguments and op results that we would like to keep, +/// based on their liveness. +/// 2.2. Find all operands that are forwarded to only dead region successor +/// inputs. I.e., forwarded to block arguments / op results that we do +/// not want to keep. +/// 2.3. Enqueue all such operands for replacement with ub.poison. /// -/// Processing a region includes: -/// (a) Collecting the uses of its unnecessary arguments (arguments forwarded -/// from unnecessary operands -/// or terminator operands). -/// (b) Collecting these unnecessary arguments. -/// (c) Collecting its unnecessary terminator operands (terminator operands -/// forwarded to unnecessary results -/// or arguments). -/// -/// Value Flow Note: In this operation, values flow as follows: -/// - From operands and terminator operands (successor operands) -/// - To arguments and results (successor inputs). +/// Note: In scenario 2, block arguments and op results are not removed. +/// However, the IR is simplified such that canonicalization patterns can +/// remove them later. static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, @@ -416,282 +388,76 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, // case, a non-forwarded operand of `regionBranchOp` could be live/non-live. // It could never be live because of this op but its liveness could have been // attributed to something else. - // Do (1') and (2'). if (isMemoryEffectFree(regionBranchOp.getOperation()) && !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) { cl.operations.push_back(regionBranchOp.getOperation()); return; } - // Mark live results of `regionBranchOp` in `liveResults`. - auto markLiveResults = [&](BitVector &liveResults) { - liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la); - }; - - // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`. - auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) { - for (Region ®ion : regionBranchOp->getRegions()) { - if (region.empty()) - continue; - SmallVector<Value> arguments(region.front().getArguments()); - BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la); - liveArgs[®ion] = regionLiveArgs; - } - }; - - // Return the successors of `region` if the latter is not null. Else return - // the successors of `regionBranchOp`. - auto getSuccessors = [&](RegionBranchPoint point) { - SmallVector<RegionSuccessor> successors; - regionBranchOp.getSuccessorRegions(point, successors); - return successors; - }; - - // Return the operands of `terminator` that are forwarded to `successor` if - // the former is not null. Else return the operands of `regionBranchOp` - // forwarded to `successor`. - auto getForwardedOpOperands = [&](RegionBranchPoint src, - const RegionSuccessor &successor) { - SmallVector<OpOperand *> opOperands = operandsToOpOperands( - regionBranchOp.getSuccessorOperands(src, successor)); - return opOperands; - }; - - // Mark the non-forwarded operands of `regionBranchOp` in - // `nonForwardedOperands`. - auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) { - nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true); - for (const RegionSuccessor &successor : - getSuccessors(RegionBranchPoint::parent())) { - for (OpOperand *opOperand : - getForwardedOpOperands(RegionBranchPoint::parent(), successor)) - nonForwardedOperands.reset(opOperand->getOperandNumber()); - } - }; - - // Mark the non-forwarded terminator operands of the various regions of - // `regionBranchOp` in `nonForwardedRets`. - auto markNonForwardedReturnValues = - [&](DenseMap<Operation *, BitVector> &nonForwardedRets) { - for (Region ®ion : regionBranchOp->getRegions()) { - if (region.empty()) - continue; - // TODO: this isn't correct in face of multiple terminators. - auto terminator = cast<RegionBranchTerminatorOpInterface>( - region.front().getTerminator()); - nonForwardedRets[terminator] = - BitVector(terminator->getNumOperands(), true); - for (const RegionSuccessor &successor : getSuccessors(terminator)) { - for (OpOperand *opOperand : getForwardedOpOperands( - RegionBranchPoint(terminator), successor)) - nonForwardedRets[terminator].reset(opOperand->getOperandNumber()); - } - } - }; - - // Update `valuesToKeep` (which is expected to correspond to operands or - // terminator operands) based on `resultsToKeep` and `argsToKeep`, given - // `region`. When `valuesToKeep` correspond to operands, `region` is null. - // Else, `region` is the parent region of the terminator. - auto updateOperandsOrTerminatorOperandsToKeep = - [&](BitVector &valuesToKeep, BitVector &resultsToKeep, - DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) { - Operation *terminator = - region ? region->front().getTerminator() : nullptr; - RegionBranchPoint point = - terminator - ? RegionBranchPoint( - cast<RegionBranchTerminatorOpInterface>(terminator)) - : RegionBranchPoint::parent(); - - for (const RegionSuccessor &successor : getSuccessors(point)) { - Region *successorRegion = successor.getSuccessor(); - for (auto [opOperand, input] : - llvm::zip(getForwardedOpOperands(point, successor), - successor.getSuccessorInputs())) { - size_t operandNum = opOperand->getOperandNumber(); - bool updateBasedOn = - successorRegion - ? argsToKeep[successorRegion] - [cast<BlockArgument>(input).getArgNumber()] - : resultsToKeep[cast<OpResult>(input).getResultNumber()]; - valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn; - } - } - }; - - // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and - // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a - // value is modified, else, false. - auto recomputeResultsAndArgsToKeep = - [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep, - BitVector &operandsToKeep, - DenseMap<Operation *, BitVector> &terminatorOperandsToKeep, - bool &resultsOrArgsToKeepChanged) { - resultsOrArgsToKeepChanged = false; - - // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`. - for (const RegionSuccessor &successor : - getSuccessors(RegionBranchPoint::parent())) { - Region *successorRegion = successor.getSuccessor(); - for (auto [opOperand, input] : - llvm::zip(getForwardedOpOperands(RegionBranchPoint::parent(), - successor), - successor.getSuccessorInputs())) { - bool recomputeBasedOn = - operandsToKeep[opOperand->getOperandNumber()]; - bool toRecompute = - successorRegion - ? argsToKeep[successorRegion] - [cast<BlockArgument>(input).getArgNumber()] - : resultsToKeep[cast<OpResult>(input).getResultNumber()]; - if (!toRecompute && recomputeBasedOn) - resultsOrArgsToKeepChanged = true; - if (successorRegion) { - argsToKeep[successorRegion][cast<BlockArgument>(input) - .getArgNumber()] = - argsToKeep[successorRegion] - [cast<BlockArgument>(input).getArgNumber()] | - recomputeBasedOn; - } else { - resultsToKeep[cast<OpResult>(input).getResultNumber()] = - resultsToKeep[cast<OpResult>(input).getResultNumber()] | - recomputeBasedOn; - } - } - } - - // Recompute `resultsToKeep` and `argsToKeep` based on - // `terminatorOperandsToKeep`. - for (Region ®ion : regionBranchOp->getRegions()) { - if (region.empty()) - continue; - auto terminator = cast<RegionBranchTerminatorOpInterface>( - region.front().getTerminator()); - for (const RegionSuccessor &successor : getSuccessors(terminator)) { - Region *successorRegion = successor.getSuccessor(); - for (auto [opOperand, input] : - llvm::zip(getForwardedOpOperands(RegionBranchPoint(terminator), - successor), - successor.getSuccessorInputs())) { - bool recomputeBasedOn = - terminatorOperandsToKeep[region.back().getTerminator()] - [opOperand->getOperandNumber()]; - bool toRecompute = - successorRegion - ? argsToKeep[successorRegion] - [cast<BlockArgument>(input).getArgNumber()] - : resultsToKeep[cast<OpResult>(input).getResultNumber()]; - if (!toRecompute && recomputeBasedOn) - resultsOrArgsToKeepChanged = true; - if (successorRegion) { - argsToKeep[successorRegion][cast<BlockArgument>(input) - .getArgNumber()] = - argsToKeep[successorRegion] - [cast<BlockArgument>(input).getArgNumber()] | - recomputeBasedOn; - } else { - resultsToKeep[cast<OpResult>(input).getResultNumber()] = - resultsToKeep[cast<OpResult>(input).getResultNumber()] | - recomputeBasedOn; - } - } - } - } - }; - - // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`, - // `operandsToKeep`, and `terminatorOperandsToKeep`. - auto markValuesToKeep = - [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep, - BitVector &operandsToKeep, - DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) { - bool resultsOrArgsToKeepChanged = true; - // We keep updating and recomputing the values until we reach a point - // where they stop changing. - while (resultsOrArgsToKeepChanged) { - // Update the operands that need to be kept. - updateOperandsOrTerminatorOperandsToKeep(operandsToKeep, - resultsToKeep, argsToKeep); - - // Update the terminator operands that need to be kept. - for (Region ®ion : regionBranchOp->getRegions()) { - if (region.empty()) - continue; - updateOperandsOrTerminatorOperandsToKeep( - terminatorOperandsToKeep[region.back().getTerminator()], - resultsToKeep, argsToKeep, ®ion); - } - - // Recompute the results and arguments that need to be kept. - recomputeResultsAndArgsToKeep( - resultsToKeep, argsToKeep, operandsToKeep, - terminatorOperandsToKeep, resultsOrArgsToKeepChanged); - } - }; - - // Scenario 2. - // At this point, we know that every non-forwarded operand of `regionBranchOp` - // is live. - - // Stores the results of `regionBranchOp` that we want to keep. - BitVector resultsToKeep; - // Stores the mapping from regions of `regionBranchOp` to their arguments that - // we want to keep. - DenseMap<Region *, BitVector> argsToKeep; - // Stores the operands of `regionBranchOp` that we want to keep. - BitVector operandsToKeep; - // Stores the mapping from region terminators in `regionBranchOp` to their - // operands that we want to keep. - DenseMap<Operation *, BitVector> terminatorOperandsToKeep; - - // Initializing the above variables... - - // The live results of `regionBranchOp` definitely need to be kept. - markLiveResults(resultsToKeep); - // Similarly, the live arguments of the regions in `regionBranchOp` definitely - // need to be kept. - markLiveArgs(argsToKeep); - // The non-forwarded operands of `regionBranchOp` definitely need to be kept. - // A live forwarded operand can be removed but no non-forwarded operand can be - // removed since it "controls" the flow of data in this control flow op. - markNonForwardedOperands(operandsToKeep); - // Similarly, the non-forwarded terminator operands of the regions in - // `regionBranchOp` definitely need to be kept. - markNonForwardedReturnValues(terminatorOperandsToKeep); - - // Mark the values (results, arguments, operands, and terminator operands) - // that we want to keep. - markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep, - terminatorOperandsToKeep); - - // Do (1). - cl.operands.push_back({regionBranchOp, operandsToKeep.flip()}); - - // Do (2.a) and (2.b). + // Compute values that are alive. + llvm::SmallDenseSet<Value> valuesToKeep; + for (Value result : regionBranchOp->getResults()) { + if (hasLive(result, nonLiveSet, la)) + valuesToKeep.insert(result); + } for (Region ®ion : regionBranchOp->getRegions()) { if (region.empty()) continue; - BitVector argsToRemove = argsToKeep[®ion].flip(); - cl.blocks.push_back({®ion.front(), argsToRemove}); - collectNonLiveValues(nonLiveSet, region.front().getArguments(), - argsToRemove); + for (Value arg : region.front().getArguments()) { + if (hasLive(arg, nonLiveSet, la)) + valuesToKeep.insert(arg); + } } - // Do (2.c). - for (Region ®ion : regionBranchOp->getRegions()) { - if (region.empty()) + // Mapping from operands to forwarded successor inputs. An operand can be + // forwarded to multiple successors. + // + // Example: + // + // %0 = scf.while : () -> i32 { + // scf.condition(...) %forwarded_value : i32 + // } do { + // ^bb0(%arg0: i32): + // scf.yield + // } + // // No uses of %0. + // + // In the above example, %forwarded_value is forwarded to %arg0 and %0. Both + // %arg0 and %0 are dead, so %forwarded_value can be replaced with a + // ub.poison result. + // + // operandToSuccessorInputs[%forwarded_value] = {%arg0, %0} + // + RegionBranchSuccessorMapping operandToSuccessorInputs; + regionBranchOp.getSuccessorOperandInputMapping(operandToSuccessorInputs); + + DenseMap<Operation *, BitVector> deadOperandsPerOp; + for (auto [opOperand, successorInputs] : operandToSuccessorInputs) { + // If one of the successor inputs is live, the respective operand must be + // kept. + bool anyAlive = llvm::any_of(successorInputs, [&](Value input) { + return valuesToKeep.contains(input); + }); + if (anyAlive) continue; - Operation *terminator = region.front().getTerminator(); - cl.operands.push_back( - {terminator, terminatorOperandsToKeep[terminator].flip()}); + + // All successor inputs are dead: ub.poison can be passed as operand. + // Create an entry in `deadOperandsPerOp` (initialized to "false", i.e., + // no "dead" op operands) if it's the first time that we are seeing an op + // operand for this op. Otherwise, just take the existing bit vector from + // the map. + BitVector &deadOperands = + deadOperandsPerOp + .try_emplace(opOperand->getOwner(), + opOperand->getOwner()->getNumOperands(), false) + .first->second; + deadOperands.set(opOperand->getOperandNumber()); } - // Do (3) and (4). - BitVector resultsToRemove = resultsToKeep.flip(); - collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(), - resultsToRemove); - cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove}); + for (auto [op, deadOperands] : deadOperandsPerOp) { + cl.operands.push_back( + {op, deadOperands, nullptr, /*replaceWithPoison=*/true}); + } } /// Steps to process a `BranchOpInterface` operation: @@ -751,11 +517,44 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, } } +/// Create ub.poison ops for the given values. If a value has no uses, return +/// an "empty" value. +static SmallVector<Value> createPoisonedValues(OpBuilder &b, + ValueRange values) { + return llvm::map_to_vector(values, [&](Value value) { + if (value.use_empty()) + return Value(); + return ub::PoisonOp::create(b, value.getLoc(), value.getType()).getResult(); + }); +} + +namespace { +/// A listener that keeps track of ub.poison ops. +struct TrackingListener : public RewriterBase::Listener { + void notifyOperationErased(Operation *op) override { + if (auto poisonOp = dyn_cast<ub::PoisonOp>(op)) + poisonOps.erase(poisonOp); + } + void notifyOperationInserted(Operation *op, + OpBuilder::InsertPoint previous) override { + if (auto poisonOp = dyn_cast<ub::PoisonOp>(op)) + poisonOps.insert(poisonOp); + } + DenseSet<ub::PoisonOp> poisonOps; +}; +} // namespace + /// Removes dead values collected in RDVFinalCleanupList. /// To be run once when all dead values have been collected. -static void cleanUpDeadVals(RDVFinalCleanupList &list) { +static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) { LDBG() << "Starting cleanup of dead values..."; + // New ub.poison ops may be inserted during cleanup. Some of these ops may no + // longer be needed after the cleanup. A tracking listener keeps track of all + // new ub.poison ops, so that they can be removed again after the cleanup. + TrackingListener listener; + IRRewriter rewriter(ctx, &listener); + // 1. Blocks, We must remove the block arguments and successor operands before // deleting the operation, as they may reside in the region operation. LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists"; @@ -773,10 +572,12 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { }); // Note: Iterate from the end to make sure that that indices of not yet // processes arguments do not change. + rewriter.setInsertionPointToStart(b.b); for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) { if (!b.nonLiveArgs[i]) continue; - b.b->getArgument(i).dropAllUses(); + b.b->getArgument(i).replaceAllUsesWith( + createPoisonedValues(rewriter, b.b->getArgument(i)).front()); b.b->eraseArgument(i); } } @@ -822,12 +623,18 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { llvm::interleaveComma(f.nonLiveRets.set_bits(), os); os << "]"; }); - // Drop all uses of the dead arguments. - for (auto deadIdx : f.nonLiveArgs.set_bits()) - f.funcOp.getArgument(deadIdx).dropAllUses(); // Some functions may not allow erasing arguments or results. These calls // return failure in such cases without modifying the function, so it's okay // to proceed. + bool hasBody = !f.funcOp.getFunctionBody().empty(); + if (hasBody) { + rewriter.setInsertionPointToStart(&f.funcOp.getFunctionBody().front()); + for (auto deadIdx : f.nonLiveArgs.set_bits()) { + f.funcOp.getArgument(deadIdx).replaceAllUsesWith( + createPoisonedValues(rewriter, f.funcOp.getArgument(deadIdx)) + .front()); + } + } if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) { // Record only if we actually erased something. if (f.nonLiveArgs.any()) @@ -881,7 +688,16 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { << OpWithFlags(o.op, OpPrintingFlags().skipRegions().printGenericOpForm()); }); - o.op->eraseOperands(o.nonLive); + if (o.replaceWithPoison) { + rewriter.setInsertionPoint(o.op); + for (auto deadIdx : o.nonLive.set_bits()) { + o.op->setOperand( + deadIdx, createPoisonedValues(rewriter, o.op->getOperand(deadIdx)) + .front()); + } + } else { + o.op->eraseOperands(o.nonLive); + } } } @@ -895,7 +711,12 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { << OpWithFlags(r.op, OpPrintingFlags().skipRegions().printGenericOpForm()); }); - dropUsesAndEraseResults(r.op, r.nonLive); + rewriter.setInsertionPoint(r.op); + for (auto deadIdx : r.nonLive.set_bits()) { + r.op->getResult(deadIdx).replaceAllUsesWith( + createPoisonedValues(rewriter, r.op->getResult(deadIdx)).front()); + } + rewriter.eraseOpResults(r.op, r.nonLive); } // 6. Operations @@ -904,13 +725,18 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { LDBG() << "Erasing operation: " << OpWithFlags(op, OpPrintingFlags().skipRegions().printGenericOpForm()); + rewriter.setInsertionPoint(op); if (op->hasTrait<OpTrait::IsTerminator>()) { // When erasing a terminator, insert an unreachable op in its place. - OpBuilder b(op); - ub::UnreachableOp::create(b, op->getLoc()); + ub::UnreachableOp::create(rewriter, op->getLoc()); } - op->dropAllUses(); - op->erase(); + rewriter.replaceOp(op, createPoisonedValues(rewriter, op->getResults())); + } + + // 7. Remove all dead poison ops. + for (ub::PoisonOp poisonOp : listener.poisonOps) { + if (poisonOp.use_empty()) + poisonOp.erase(); } LDBG() << "Finished cleanup of dead values"; @@ -951,7 +777,29 @@ void RemoveDeadValues::runOnOperation() { } }); - cleanUpDeadVals(finalCleanupList); + MLIRContext *context = module->getContext(); + cleanUpDeadVals(context, finalCleanupList); + + if (!canonicalize) + return; + + // Canonicalize all region branch ops. + SmallVector<Operation *> opsToCanonicalize; + module->walk([&](RegionBranchOpInterface regionBranchOp) { + opsToCanonicalize.push_back(regionBranchOp.getOperation()); + }); + // TODO: Apply only region branch op canonicalization patterns or find a + // better API to collect all canonicalization patterns. + RewritePatternSet owningPatterns(context); + for (auto *dialect : context->getLoadedDialects()) + dialect->getCanonicalizationPatterns(owningPatterns); + for (RegisteredOperationName op : context->getRegisteredOperations()) + op.getCanonicalizationPatterns(owningPatterns, context); + if (failed(applyOpPatternsGreedily(opsToCanonicalize, + std::move(owningPatterns)))) { + module->emitError("greedy pattern rewrite failed to converge"); + signalPassFailure(); + } } std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() { diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index bd730915c6dcd..2584573c8b4dc 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s -remove-dead-values -split-input-file -verify-diagnostics | FileCheck %s +// RUN: mlir-opt %s -remove-dead-values="canonicalize=0" -split-input-file | FileCheck %s +// RUN: mlir-opt %s -remove-dead-values="canonicalize=1" -split-input-file | FileCheck %s --check-prefix=CHECK-CANONICALIZE // The IR is updated regardless of memref.global private constant // @@ -55,19 +56,20 @@ func.func @acceptable_ir_has_cleanable_loop_of_conditional_and_branch_op(%arg0: // Checking that iter_args are properly handled // +// CHECK-CANONICALIZE-LABEL: func @cleanable_loop_iter_args_value func.func @cleanable_loop_iter_args_value(%arg0: index) -> index { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c10 = arith.constant 10 : index %non_live = arith.constant 0 : index - // CHECK: [[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 step %c1 iter_args([[ARG_2:%.*]] = %arg0) -> (index) { + // CHECK-CANONICALIZE: [[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 step %c1 iter_args([[ARG_2:%.*]] = %arg0) -> (index) { %result, %result_non_live = scf.for %i = %c0 to %c10 step %c1 iter_args(%live_arg = %arg0, %non_live_arg = %non_live) -> (index, index) { - // CHECK: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : index + // CHECK-CANONICALIZE: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : index %new_live = arith.addi %live_arg, %i : index - // CHECK: scf.yield [[SUM:%.+]] + // CHECK-CANONICALIZE: scf.yield [[SUM:%.+]] scf.yield %new_live, %non_live_arg : index, index } - // CHECK: return [[RESULT]] : index + // CHECK-CANONICALIZE: return [[RESULT]] : index return %result : index } @@ -79,7 +81,8 @@ func.func @cleanable_loop_iter_args_value(%arg0: index) -> index { #map = affine_map<(d0, d1, d2) -> (0, d1, d2)> #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> module { - func.func @main() { + // CHECK-LABEL: @dead_linalg_generic + func.func @dead_linalg_generic() { %cst_3 = arith.constant dense<54> : tensor<1x25x13xi32> %cst_7 = arith.constant dense<11> : tensor<1x25x13xi32> // CHECK-NOT: arith.constant @@ -229,18 +232,34 @@ func.func @main() -> (i32, i32) { // anywhere else. Thus, %arg7 is also not kept in the `scf.yield` op. // // Note that this cleanup cannot be done by the `canonicalize` pass. -// -// CHECK: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 { -// CHECK-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) { -// CHECK-NEXT: %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]] -// CHECK-NEXT: scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32 + +// CHECK-LABEL: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand( +// CHECK-SAME: %[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 { +// CHECK-NEXT: %[[p0:.*]] = ub.poison : i32 +// CHECK-NEXT: %[[while:.*]]:3 = scf.while (%{{.*}} = %[[p0]], %[[arg4:.*]] = %[[arg2]]) : (i32, i32) -> (i32, i32, i32) { +// CHECK-NEXT: %[[add1:.*]] = arith.addi %[[arg4]], %[[arg4]] : i32 +// CHECK-NEXT: %[[p1:.*]] = ub.poison : i32 +// CHECK-NEXT: scf.condition(%[[arg0]]) %[[add1]], %[[arg4]], %[[p1]] : i32, i32, i32 // CHECK-NEXT: } do { -// CHECK-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32): -// CHECK-NEXT: %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]] -// CHECK-NEXT: scf.yield %[[live_1]] : i32 +// CHECK-NEXT: ^bb0(%{{.*}}: i32, %[[arg6:.*]]: i32, %{{.*}}: i32): +// CHECK-NEXT: %[[add2:.*]] = arith.addi %[[arg6]], %[[arg6]] : i32 +// CHECK-NEXT: %[[p2:.*]] = ub.poison : i32 +// CHECK-NEXT: scf.yield %[[p2]], %[[add2]] : i32, i32 // CHECK-NEXT: } -// CHECK-NEXT: return %[[live_and_non_live]]#0 +// CHECK-NEXT: return %[[while]]#0 : i32 // CHECK-NEXT: } + +// CHECK-CANONICALIZE: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 { +// CHECK-CANONICALIZE-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) { +// CHECK-CANONICALIZE-NEXT: %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]] +// CHECK-CANONICALIZE-NEXT: scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32 +// CHECK-CANONICALIZE-NEXT: } do { +// CHECK-CANONICALIZE-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32): +// CHECK-CANONICALIZE-NEXT: %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]] +// CHECK-CANONICALIZE-NEXT: scf.yield %[[live_1]] : i32 +// CHECK-CANONICALIZE-NEXT: } +// CHECK-CANONICALIZE-NEXT: return %[[live_and_non_live]]#0 +// CHECK-CANONICALIZE-NEXT: } func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%arg0: i1, %arg1: i32, %arg2: i32) -> (i32) { %live, %non_live, %non_live_0 = scf.while (%arg3 = %arg1, %arg4 = %arg2) : (i32, i32) -> (i32, i32, i32) { %live_0 = arith.addi %arg4, %arg4 : i32 @@ -284,21 +303,21 @@ func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_o // // Note that this cleanup cannot be done by the `canonicalize` pass. // -// CHECK: func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 { -// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 -// CHECK-NEXT: %[[c1:.*]] = arith.constant 1 -// CHECK-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) { -// CHECK-NEXT: func.call @identity() : () -> () -// CHECK-NEXT: scf.condition(%[[arg2]]) %[[arg4]], %[[arg3]] : i32, i32 -// CHECK-NEXT: } do { -// CHECK-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32): -// CHECK-NEXT: scf.yield %[[arg5]], %[[arg6]] : i32, i32 -// CHECK-NEXT: } -// CHECK-NEXT: return %[[live_and_non_live]]#0 : i32 -// CHECK-NEXT: } -// CHECK: func.func private @identity() { -// CHECK-NEXT: return -// CHECK-NEXT: } +// CHECK-CANONICALIZE: func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 { +// CHECK-CANONICALIZE-NEXT: %[[c0:.*]] = arith.constant 0 +// CHECK-CANONICALIZE-NEXT: %[[c1:.*]] = arith.constant 1 +// CHECK-CANONICALIZE-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) { +// CHECK-CANONICALIZE-NEXT: func.call @identity() : () -> () +// CHECK-CANONICALIZE-NEXT: scf.condition(%[[arg2]]) %[[arg3]], %[[arg4]] : i32, i32 +// CHECK-CANONICALIZE-NEXT: } do { +// CHECK-CANONICALIZE-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32): +// CHECK-CANONICALIZE-NEXT: scf.yield %[[arg6]], %[[arg5]] : i32, i32 +// CHECK-CANONICALIZE-NEXT: } +// CHECK-CANONICALIZE-NEXT: return %[[live_and_non_live]]#1 : i32 +// CHECK-CANONICALIZE-NEXT: } +// CHECK-CANONICALIZE: func.func private @identity() { +// CHECK-CANONICALIZE-NEXT: return +// CHECK-CANONICALIZE-NEXT: } func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%arg2: i1) -> (i32) { %c0 = arith.constant 0 : i32 %c1 = arith.constant 1 : i32 @@ -325,17 +344,17 @@ func.func private @identity(%arg1 : i32) -> (i32) { // // Note that this cleanup cannot be done by the `canonicalize` pass. // -// CHECK: func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref<i32>) { -// CHECK-NEXT: scf.index_switch %[[arg0]] -// CHECK-NEXT: case 1 { -// CHECK-NEXT: %[[c10:.*]] = arith.constant 10 -// CHECK-NEXT: memref.store %[[c10]], %[[arg1]][] -// CHECK-NEXT: scf.yield -// CHECK-NEXT: } -// CHECK-NEXT: default { -// CHECK-NEXT: } -// CHECK-NEXT: return -// CHECK-NEXT: } +// CHECK-CANONICALIZE: func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref<i32>) { +// CHECK-CANONICALIZE-NEXT: scf.index_switch %[[arg0]] +// CHECK-CANONICALIZE-NEXT: case 1 { +// CHECK-CANONICALIZE-NEXT: %[[c10:.*]] = arith.constant 10 +// CHECK-CANONICALIZE-NEXT: memref.store %[[c10]], %[[arg1]][] +// CHECK-CANONICALIZE: scf.yield +// CHECK-CANONICALIZE-NEXT: } +// CHECK-CANONICALIZE-NEXT: default { +// CHECK-CANONICALIZE: } +// CHECK-CANONICALIZE-NEXT: return +// CHECK-CANONICALIZE-NEXT: } func.func @clean_region_branch_op_remove_result(%arg0 : index, %arg1 : memref<i32>) { %non_live = scf.index_switch %arg0 -> i32 case 1 { @@ -539,10 +558,10 @@ module { } } -// CHECK-LABEL: func @test_zero_operands -// CHECK: memref.alloca_scope -// CHECK: memref.store -// CHECK-NOT: memref.alloca_scope.return +// CHECK-CANONICALIZE-LABEL: func @test_zero_operands +// CHECK-CANONICALIZE-NEXT: %[[c0:.*]] = arith.constant 0 +// CHECK-CANONICALIZE-NEXT: memref.store %[[c0]] +// CHECK-CANONICALIZE-NOT: memref.alloca_scope.return // ----- @@ -731,3 +750,49 @@ func.func @affine_loop_no_use_iv_has_side_effect_op() { // CHECK: } return } + +// ----- + +// CHECK-LABEL: func @scf_while_dead_iter_args() +// CHECK: %[[c5:.*]] = arith.constant 5 : i32 +// CHECK: %[[while:.*]]:2 = scf.while (%[[arg0:.*]] = %[[c5]]) : (i32) -> (i32, i32) { +// CHECK: vector.print %[[arg0]] +// CHECK: %[[cmpi:.*]] = arith.cmpi +// CHECK: %[[p0:.*]] = ub.poison : i32 +// CHECK: scf.condition(%[[cmpi]]) %[[arg0]], %[[p0]] +// CHECK: } do { +// CHECK: ^bb0(%[[arg1:.*]]: i32, %[[arg2:.*]]: i32): +// CHECK: %[[p1:.*]] = ub.poison : i32 +// CHECK: scf.yield %[[p1]] +// CHECK: } +// CHECK: return %[[while]]#0 + +// CHECK-CANONICALIZE-LABEL: func @scf_while_dead_iter_args() +// CHECK-CANONICALIZE: %[[c5:.*]] = arith.constant 5 : i32 +// CHECK-CANONICALIZE: %[[while:.*]] = scf.while (%[[arg0:.*]] = %[[c5]]) : (i32) -> i32 { +// CHECK-CANONICALIZE: vector.print %[[arg0]] +// CHECK-CANONICALIZE: %[[cmpi:.*]] = arith.cmpi +// CHECK-CANONICALIZE: scf.condition(%[[cmpi]]) %[[arg0]] +// CHECK-CANONICALIZE: } do { +// CHECK-CANONICALIZE: ^bb0(%[[arg1:.*]]: i32): +// CHECK-CANONICALIZE: %[[p0:.*]] = ub.poison : i32 +// CHECK-CANONICALIZE: scf.yield %[[p0]] +// CHECK-CANONICALIZE: } +// CHECK-CANONICALIZE: return %[[while]] +func.func @scf_while_dead_iter_args() -> i32 { + %c5 = arith.constant 5 : i32 + %result:2 = scf.while (%arg0 = %c5) : (i32) -> (i32, i32) { + vector.print %arg0 : i32 + // Note: This condition is always "false". (And the liveness analysis + // can figure that out.) + %cmp2 = arith.cmpi slt, %arg0, %c5 : i32 + scf.condition(%cmp2) %arg0, %arg0 : i32, i32 + } do { + ^bb0(%arg1: i32, %arg2: i32): + %x = scf.execute_region -> i32 { + scf.yield %arg2 : i32 + } + scf.yield %x : i32 + } + return %result#0 : i32 +} _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
