https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/174094
>From da0a853c603fd3acaddf502d6520c70376c77481 Mon Sep 17 00:00:00 2001 From: Matthias Springer <[email protected]> Date: Wed, 31 Dec 2025 14:07:51 +0000 Subject: [PATCH 1/2] [mlir][draft] Consolidate patterns into RegionBranchOpInterface patterns --- .../mlir/Interfaces/ControlFlowInterfaces.h | 2 + .../mlir/Interfaces/ControlFlowInterfaces.td | 9 + mlir/lib/Dialect/SCF/IR/SCF.cpp | 908 ++++-------------- mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 39 + mlir/test/Dialect/SCF/canonicalize.mlir | 12 +- 5 files changed, 243 insertions(+), 727 deletions(-) diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h index 566f4b8fadb5d..a7565f9f7bb78 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h @@ -188,6 +188,8 @@ LogicalResult verifyTypesAlongControlFlowEdges(Operation *op); /// possible successors.) Operands that not forwarded at all are not present in /// the mapping. using RegionBranchSuccessorMapping = DenseMap<OpOperand *, SmallVector<Value>>; +using RegionBranchInverseSuccessorMapping = + DenseMap<Value, SmallVector<OpOperand *>>; /// This class represents a successor of a region. A region successor can either /// be another region, or the parent operation. If the successor is a region, diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td index 2e654ba04ffe5..9366e5562b774 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -355,6 +355,15 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> { ::mlir::RegionBranchSuccessorMapping &mapping, std::optional<::mlir::RegionBranchPoint> src = std::nullopt); + /// Build a mapping from successor inputs to successor operands. This is + /// the same as "getSuccessorOperandInputMapping", but inverted. 
+ void getSuccessorInputOperandMapping( + ::mlir::RegionBranchInverseSuccessorMapping &mapping); + + /// Compute all values that a successor input could possibly have. If the + /// given value is not a successor input, return an empty set. + ::llvm::DenseSet<Value> computePossibleValuesOfSuccessorInput(::mlir::Value value); + /// Return all possible region branch points: the region branch op itself /// and all region branch terminators. ::llvm::SmallVector<::mlir::RegionBranchPoint> getAllRegionBranchPoints(); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 46d09abd89d69..6e1538676b1e5 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -27,6 +27,7 @@ #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Transforms/InliningUtils.h" #include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" @@ -291,102 +292,9 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { } }; -// Pattern to eliminate ExecuteRegionOp results which forward external -// values from the region. In case there are multiple yield operations, -// all of them must have the same operands in order for the pattern to be -// applicable. -struct ExecuteRegionForwardingEliminator - : public OpRewritePattern<ExecuteRegionOp> { - using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(ExecuteRegionOp op, - PatternRewriter &rewriter) const override { - if (op.getNumResults() == 0) - return failure(); - - SmallVector<Operation *> yieldOps; - for (Block &block : op.getRegion()) { - if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator())) - yieldOps.push_back(yield.getOperation()); - } - - if (yieldOps.empty()) - return failure(); - - // Check if all yield operations have the same operands. 
- auto yieldOpsOperands = yieldOps[0]->getOperands(); - for (auto *yieldOp : yieldOps) { - if (yieldOp->getOperands() != yieldOpsOperands) - return failure(); - } - - SmallVector<Value> externalValues; - SmallVector<Value> internalValues; - SmallVector<Value> opResultsToReplaceWithExternalValues; - SmallVector<Value> opResultsToKeep; - for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) { - if (isValueFromInsideRegion(yieldedValue, op)) { - internalValues.push_back(yieldedValue); - opResultsToKeep.push_back(op.getResult(index)); - } else { - externalValues.push_back(yieldedValue); - opResultsToReplaceWithExternalValues.push_back(op.getResult(index)); - } - } - // No yielded external values - nothing to do. - if (externalValues.empty()) - return failure(); - - // There are yielded external values - create a new execute_region returning - // just the internal values. - SmallVector<Type> resultTypes; - for (Value value : internalValues) - resultTypes.push_back(value.getType()); - auto newOp = - ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes)); - newOp->setAttrs(op->getAttrs()); - - // Move old op's region to the new operation. - rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), - newOp.getRegion().end()); - - // Replace all yield operations with a new yield operation with updated - // results. scf.execute_region must have at least one yield operation. - for (auto *yieldOp : yieldOps) { - rewriter.setInsertionPoint(yieldOp); - rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, - ValueRange(internalValues)); - } - - // Replace the old operation with the external values directly. - rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues, - externalValues); - // Replace the old operation's remaining results with the new operation's - // results. 
- rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults()); - rewriter.eraseOp(op); - return success(); - } - -private: - bool isValueFromInsideRegion(Value value, - ExecuteRegionOp executeRegionOp) const { - // Check if the value is defined within the execute_region - if (Operation *defOp = value.getDefiningOp()) - return &executeRegionOp.getRegion() == defOp->getParentRegion(); - - // If it's a block argument, check if it's from within the region - if (BlockArgument blockArg = dyn_cast<BlockArgument>(value)) - return &executeRegionOp.getRegion() == blockArg.getParentRegion(); - - return false; // Value is from outside the region - } -}; - void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner, - ExecuteRegionForwardingEliminator>(context); + results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context); } void ExecuteRegionOp::getSuccessorRegions( @@ -1234,91 +1142,199 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> { } }; -/// Rewriting pattern that folds away cycles in the yield of a scf.for op. -/// -/// ``` -/// %res:2 = scf.for ... iter_args(%arg0 = %init, %arg1 = %init) { -/// ... -/// use %arg0, %arg1 -/// scf.yield %arg1, %arg0 -/// } -/// return %res#0, %res#1 -/// ``` -/// -/// folds into: -/// -/// ``` -/// scf.for ... iter_args() { -/// ... -/// use %init, %init -/// scf.yield -/// } -/// return %init, %init -/// ``` -struct ForOpYieldCyclesFolder : public OpRewritePattern<ForOp> { - using Base::Base; +/// Is a defined before b? 
+static bool isDefinedBefore(Value a, Value b) { + Region *aRegion = a.getParentRegion(); + Region *bRegion = b.getParentRegion(); - LogicalResult matchAndRewrite(ForOp op, + if (aRegion->getParentOp()->isProperAncestor(bRegion->getParentOp())) { + return true; + } + if (aRegion == bRegion) { + Block *aBlock = a.getParentBlock(); + Block *bBlock = b.getParentBlock(); + if (aBlock != bBlock) + return false; + if (isa<BlockArgument>(a)) + return true; + if (isa<BlockArgument>(b)) + return false; + return a.getDefiningOp()->isBeforeInBlock(b.getDefiningOp()); + } + + return false; +} + +// Try to make successor inputs dead by replacing their uses with values that +// are not successor inputs. This pattern enables additional canonicalization +// opportunities for RemoveDeadValues. +struct RemoveUsesOfIdenticalValues + : public OpInterfaceRewritePattern<RegionBranchOpInterface> { + using OpInterfaceRewritePattern< + RegionBranchOpInterface>::OpInterfaceRewritePattern; + + LogicalResult matchAndRewrite(RegionBranchOpInterface op, PatternRewriter &rewriter) const override { - ValueRange yieldedValues = op.getYieldedValues(); - ValueRange initArgs = op.getInitArgs(); - ValueRange results = op.getResults(); - ValueRange regionIterArgs = op.getRegionIterArgs(); - Block *body = op.getBody(); + // TODO: ForallOp data flow is modeled incompletely. + if (isa<ForallOp>(op)) + return failure(); - unsigned numYieldedValues = op.getNumRegionIterArgs(); + // Gather all potential successor inputs. (Other values may also be + // included, but we're not doing anything with them.) 
+ SmallVector<Value> values; + llvm::append_range(values, op->getResults()); + for (Region &r : op->getRegions()) + llvm::append_range(values, r.getArguments()); bool changed = false; - SmallVector<unsigned> cycle; - llvm::SmallBitVector visited(numYieldedValues, false); + for (Value value : values) { + if (value.use_empty()) + continue; + DenseSet<Value> possibleValues = + op.computePossibleValuesOfSuccessorInput(value); + if (possibleValues.size() == 1 && *possibleValues.begin() != value && + isDefinedBefore(*possibleValues.begin(), value)) { + // Value is same as another value. + rewriter.replaceAllUsesWith(value, *possibleValues.begin()); + changed = true; + } + } + return success(changed); + } +}; - // Go through all possible start points for the cycle. - for (auto start : llvm::seq(numYieldedValues)) { - if (visited[start]) +/// Pattern to remove dead values from region branch ops. +struct RemoveDeadValues + : public OpInterfaceRewritePattern<RegionBranchOpInterface> { + using OpInterfaceRewritePattern< + RegionBranchOpInterface>::OpInterfaceRewritePattern; + + LogicalResult matchAndRewrite(RegionBranchOpInterface op, + PatternRewriter &rewriter) const override { + // TODO: ForallOp data flow is modeled incompletely. + if (isa<ForallOp>(op)) + return failure(); + + // Compute tied values: values that must come as a set. If you remove one, + // you must remove all. + RegionBranchSuccessorMapping operandToInputs; + op.getSuccessorOperandInputMapping(operandToInputs); + llvm::EquivalenceClasses<Value> tiedSuccessorInputs; + for (const auto &[operand, inputs] : operandToInputs) { + assert(!inputs.empty() && "expected non-empty inputs"); + Value firstInput = inputs.front(); + tiedSuccessorInputs.insert(firstInput); + for (Value nextInput : llvm::drop_begin(inputs)) + tiedSuccessorInputs.unionSets(firstInput, nextInput); + } + + // Determine which values to remove and group them by block and operation. 
+ SmallVector<Value> valuesToRemove; + DenseMap<Block *, BitVector> blockArgsToRemove; + DenseMap<Operation *, BitVector> resultsToRemove; + for (auto it = tiedSuccessorInputs.begin(), e = tiedSuccessorInputs.end(); + it != e; ++it) { + if (!(*it)->isLeader()) continue; - cycle.clear(); - unsigned current = start; - bool validCycle = true; - Value initValue = initArgs[start]; - // Go through yield -> block arg -> yield cycles and check if all values - // are always equal to the init. - while (!visited[current]) { - cycle.push_back(current); - visited[current] = true; - - // Find whether this yield is from a region iter arg. - auto yieldedValue = yieldedValues[current]; - if (auto arg = dyn_cast<BlockArgument>(yieldedValue); - !arg || arg.getOwner() != body) { - validCycle = false; + // Value can be removed if it is dead and all other tied values are also + // dead. + bool allDead = true; + for (auto memberIt = tiedSuccessorInputs.member_begin(**it); + memberIt != tiedSuccessorInputs.member_end(); ++memberIt) { + if (!memberIt->use_empty()) { + allDead = false; break; } + } + if (!allDead) + continue; - // Next yield position. - current = cast<BlockArgument>(yieldedValue).getArgNumber() - - op.getNumInductionVars(); - - // Check if next position has the same init value. - if (initArgs[current] != initValue) { - validCycle = false; - break; + // Group values by block and operation. 
+ for (auto memberIt = tiedSuccessorInputs.member_begin(**it); + memberIt != tiedSuccessorInputs.member_end(); ++memberIt) { + if (auto arg = dyn_cast<BlockArgument>(*memberIt)) { + BitVector &vector = + blockArgsToRemove + .try_emplace(arg.getOwner(), + arg.getOwner()->getNumArguments(), false) + .first->second; + vector.set(arg.getArgNumber()); + } else { + OpResult result = cast<OpResult>(*memberIt); + BitVector &vector = + resultsToRemove + .try_emplace(result.getDefiningOp(), + result.getDefiningOp()->getNumResults(), false) + .first->second; + vector.set(result.getResultNumber()); } + valuesToRemove.push_back(*memberIt); } + } - // If we found a valid cycle (yielding own iter arg forms cycle of length - // 1), all values in it are always equal to initValue. - if (validCycle) { - changed = true; - for (unsigned idx : cycle) { - // This will leave region args and results dead so other - // canonicalization patterns can clean them up. - rewriter.replaceAllUsesWith(regionIterArgs[idx], initValue); - rewriter.replaceAllUsesWith(results[idx], initValue); + if (valuesToRemove.empty()) + return rewriter.notifyMatchFailure(op, "no values to remove"); + + // Find operands that must be removed together with the values. + RegionBranchInverseSuccessorMapping inputsToOperands; + op.getSuccessorInputOperandMapping(inputsToOperands); + DenseMap<Operation *, llvm::BitVector> operandsToRemove; + for (Value value : valuesToRemove) { + for (OpOperand *operand : inputsToOperands[value]) { + BitVector &vector = + operandsToRemove + .try_emplace(operand->getOwner(), + operand->getOwner()->getNumOperands(), false) + .first->second; + vector.set(operand->getOperandNumber()); + } + } + + // Erase operands. + for (auto [op, operands] : operandsToRemove) { + rewriter.modifyOpInPlace(op, [&]() { op->eraseOperands(operands); }); + } + + // Erase block arguments. 
+ for (auto [block, blockArgs] : blockArgsToRemove) { + rewriter.modifyOpInPlace(block->getParentOp(), + [&]() { block->eraseArguments(blockArgs); }); + } + + // Erase op results. + // TODO: Can we move this to RewriterBase, so we have a uniform API, + // similar to eraseArguments? + for (auto [op, resultsToErase] : resultsToRemove) { + rewriter.setInsertionPoint(op); + SmallVector<Type> newResultTypes; + for (OpResult result : op->getResults()) + if (!resultsToErase[result.getResultNumber()]) + newResultTypes.push_back(result.getType()); + OperationState state(op->getLoc(), op->getName().getStringRef(), + op->getOperands(), newResultTypes, op->getAttrs()); + for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) + state.addRegion(); + Operation *newOp = rewriter.create(state); + for (const auto &[index, region] : llvm::enumerate(op->getRegions())) { + // Move all blocks of `region` into `newRegion`. + Region &newRegion = newOp->getRegion(index); + rewriter.inlineRegionBefore(region, newRegion, newRegion.begin()); + } + + SmallVector<Value> newResults; + unsigned nextLiveResult = 0; + for (auto [index, result] : llvm::enumerate(op->getResults())) { + if (!resultsToErase[index]) { + newResults.push_back(newOp->getResult(nextLiveResult++)); + } else { + newResults.push_back(Value()); } } + rewriter.replaceOp(op, newResults); } - return success(changed); + + return success(); } }; @@ -1326,8 +1342,11 @@ struct ForOpYieldCyclesFolder : public OpRewritePattern<ForOp> { void ForOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder, - ForOpYieldCyclesFolder>(context); + // TODO: ForOpIterArgsFolder also removes duplicate bbargs, can this be done + // in the two new patterns? 
+ results.add</*ForOpIterArgsFolder, */ SimplifyTrivialLoops, + ForOpTensorCastFolder, RemoveUsesOfIdenticalValues, + RemoveDeadValues>(context); } std::optional<APInt> ForOp::getConstantStep() { @@ -2495,61 +2514,6 @@ void IfOp::getRegionInvocationBounds( } namespace { -// Pattern to remove unused IfOp results. -struct RemoveUnusedResults : public OpRewritePattern<IfOp> { - using OpRewritePattern<IfOp>::OpRewritePattern; - - void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults, - PatternRewriter &rewriter) const { - // Move all operations to the destination block. - rewriter.mergeBlocks(source, dest); - // Replace the yield op by one that returns only the used values. - auto yieldOp = cast<scf::YieldOp>(dest->getTerminator()); - SmallVector<Value, 4> usedOperands; - llvm::transform(usedResults, std::back_inserter(usedOperands), - [&](OpResult result) { - return yieldOp.getOperand(result.getResultNumber()); - }); - rewriter.modifyOpInPlace(yieldOp, - [&]() { yieldOp->setOperands(usedOperands); }); - } - - LogicalResult matchAndRewrite(IfOp op, - PatternRewriter &rewriter) const override { - // Compute the list of used results. - SmallVector<OpResult, 4> usedResults; - llvm::copy_if(op.getResults(), std::back_inserter(usedResults), - [](OpResult result) { return !result.use_empty(); }); - - // Replace the operation if only a subset of its results have uses. - if (usedResults.size() == op.getNumResults()) - return failure(); - - // Compute the result types of the replacement operation. - SmallVector<Type, 4> newTypes; - llvm::transform(usedResults, std::back_inserter(newTypes), - [](OpResult result) { return result.getType(); }); - - // Create a replacement operation with empty then and else regions. 
- auto newOp = - IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition()); - rewriter.createBlock(&newOp.getThenRegion()); - rewriter.createBlock(&newOp.getElseRegion()); - - // Move the bodies and replace the terminators (note there is a then and - // an else region since the operation returns results). - transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter); - transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter); - - // Replace the operation by the new one. - SmallVector<Value, 4> repResults(op.getNumResults()); - for (const auto &en : llvm::enumerate(usedResults)) - repResults[en.value().getResultNumber()] = newOp.getResult(en.index()); - rewriter.replaceOp(op, repResults); - return success(); - } -}; - struct RemoveStaticCondition : public OpRewritePattern<IfOp> { using OpRewritePattern<IfOp>::OpRewritePattern; @@ -3120,8 +3084,8 @@ void IfOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add<CombineIfs, CombineNestedIfs, ConditionPropagation, ConvertTrivialIfToSelect, RemoveEmptyElseBranch, - RemoveStaticCondition, RemoveUnusedResults, - ReplaceIfYieldWithConditionOrValue>(context); + RemoveStaticCondition, ReplaceIfYieldWithConditionOrValue>( + context); } Block *IfOp::thenBlock() { return &getThenRegion().back(); } @@ -3959,390 +3923,6 @@ struct WhileConditionTruth : public OpRewritePattern<WhileOp> { } }; -/// Remove loop invariant arguments from `before` block of scf.while. -/// A before block argument is considered loop invariant if :- -/// 1. i-th yield operand is equal to the i-th while operand. -/// 2. i-th yield operand is k-th after block argument which is (k+1)-th -/// condition operand AND this (k+1)-th condition operand is equal to i-th -/// iter argument/while operand. -/// For the arguments which are removed, their uses inside scf.while -/// are replaced with their corresponding initial value. 
-/// -/// Eg: -/// INPUT :- -/// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b, -/// ..., %argN_before = %N) -/// { -/// ... -/// scf.condition(%cond) %arg1_before, %arg0_before, -/// %arg2_before, %arg0_before, ... -/// } do { -/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2, -/// ..., %argK_after): -/// ... -/// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN -/// } -/// -/// OUTPUT :- -/// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before = -/// %N) -/// { -/// ... -/// scf.condition(%cond) %b, %a, %arg2_before, %a, ... -/// } do { -/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2, -/// ..., %argK_after): -/// ... -/// scf.yield %arg1_after, ..., %argN -/// } -/// -/// EXPLANATION: -/// We iterate over each yield operand. -/// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand -/// %arg0_before, which in turn is the 0-th iter argument. So we -/// remove 0-th before block argument and yield operand, and replace -/// all uses of the 0-th before block argument with its initial value -/// %a. -/// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial -/// value. So we remove this operand and the corresponding before -/// block argument and replace all uses of 1-th before block argument -/// with %b. 
-struct RemoveLoopInvariantArgsFromBeforeBlock - : public OpRewritePattern<WhileOp> { - using OpRewritePattern<WhileOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(WhileOp op, - PatternRewriter &rewriter) const override { - Block &afterBlock = *op.getAfterBody(); - Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments(); - ConditionOp condOp = op.getConditionOp(); - OperandRange condOpArgs = condOp.getArgs(); - Operation *yieldOp = afterBlock.getTerminator(); - ValueRange yieldOpArgs = yieldOp->getOperands(); - - bool canSimplify = false; - for (const auto &it : - llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) { - auto index = static_cast<unsigned>(it.index()); - auto [initVal, yieldOpArg] = it.value(); - // If i-th yield operand is equal to the i-th operand of the scf.while, - // the i-th before block argument is a loop invariant. - if (yieldOpArg == initVal) { - canSimplify = true; - break; - } - // If the i-th yield operand is k-th after block argument, then we check - // if the (k+1)-th condition op operand is equal to either the i-th before - // block argument or the initial value of i-th before block argument. If - // the comparison results `true`, i-th before block argument is a loop - // invariant. 
- auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg); - if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) { - Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()]; - if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) { - canSimplify = true; - break; - } - } - } - - if (!canSimplify) - return failure(); - - SmallVector<Value> newInitArgs, newYieldOpArgs; - DenseMap<unsigned, Value> beforeBlockInitValMap; - SmallVector<Location> newBeforeBlockArgLocs; - for (const auto &it : - llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) { - auto index = static_cast<unsigned>(it.index()); - auto [initVal, yieldOpArg] = it.value(); - - // If i-th yield operand is equal to the i-th operand of the scf.while, - // the i-th before block argument is a loop invariant. - if (yieldOpArg == initVal) { - beforeBlockInitValMap.insert({index, initVal}); - continue; - } else { - // If the i-th yield operand is k-th after block argument, then we check - // if the (k+1)-th condition op operand is equal to either the i-th - // before block argument or the initial value of i-th before block - // argument. If the comparison results `true`, i-th before block - // argument is a loop invariant. 
- auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg); - if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) { - Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()]; - if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) { - beforeBlockInitValMap.insert({index, initVal}); - continue; - } - } - } - newInitArgs.emplace_back(initVal); - newYieldOpArgs.emplace_back(yieldOpArg); - newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc()); - } - - { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(yieldOp); - rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs); - } - - auto newWhile = WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(), - newInitArgs); - - Block &newBeforeBlock = *rewriter.createBlock( - &newWhile.getBefore(), /*insertPt*/ {}, - ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs); - - Block &beforeBlock = *op.getBeforeBody(); - SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments()); - // For each i-th before block argument we find it's replacement value as :- - // 1. If i-th before block argument is a loop invariant, we fetch it's - // initial value from `beforeBlockInitValMap` by querying for key `i`. - // 2. Else we fetch j-th new before block argument as the replacement - // value of i-th before block argument. - for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) { - // If the index 'i' argument was a loop invariant we fetch it's initial - // value from `beforeBlockInitValMap`. 
- if (beforeBlockInitValMap.count(i) != 0) - newBeforeBlockArgs[i] = beforeBlockInitValMap[i]; - else - newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++); - } - - rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs); - rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(), - newWhile.getAfter().begin()); - - rewriter.replaceOp(op, newWhile.getResults()); - return success(); - } -}; - -/// Remove loop invariant value from result (condition op) of scf.while. -/// A value is considered loop invariant if the final value yielded by -/// scf.condition is defined outside of the `before` block. We remove the -/// corresponding argument in `after` block and replace the use with the value. -/// We also replace the use of the corresponding result of scf.while with the -/// value. -/// -/// Eg: -/// INPUT :- -/// %res_input:K = scf.while <...> iter_args(%arg0_before = , ..., -/// %argN_before = %N) { -/// ... -/// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ... -/// } do { -/// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after): -/// ... -/// some_func(%arg1_after) -/// ... -/// scf.yield %arg0_after, %arg2_after, ..., %argN_after -/// } -/// -/// OUTPUT :- -/// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) { -/// ... -/// scf.condition(%cond) %arg0, %arg1, ..., %argM -/// } do { -/// ^bb0(%arg0, %arg3, ..., %argM): -/// ... -/// some_func(%a) -/// ... -/// scf.yield %arg0, %b, ..., %argN -/// } -/// -/// EXPLANATION: -/// 1. The 1-th and 2-th operand of scf.condition are defined outside the -/// before block of scf.while, so they get removed. -/// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are -/// replaced by %b. -/// 3. The corresponding after block argument %arg1_after's uses are -/// replaced by %a and %arg2_after's uses are replaced by %b. 
-struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> { - using OpRewritePattern<WhileOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(WhileOp op, - PatternRewriter &rewriter) const override { - Block &beforeBlock = *op.getBeforeBody(); - ConditionOp condOp = op.getConditionOp(); - OperandRange condOpArgs = condOp.getArgs(); - - bool canSimplify = false; - for (Value condOpArg : condOpArgs) { - // Those values not defined within `before` block will be considered as - // loop invariant values. We map the corresponding `index` with their - // value. - if (condOpArg.getParentBlock() != &beforeBlock) { - canSimplify = true; - break; - } - } - - if (!canSimplify) - return failure(); - - Block::BlockArgListType afterBlockArgs = op.getAfterArguments(); - - SmallVector<Value> newCondOpArgs; - SmallVector<Type> newAfterBlockType; - DenseMap<unsigned, Value> condOpInitValMap; - SmallVector<Location> newAfterBlockArgLocs; - for (const auto &it : llvm::enumerate(condOpArgs)) { - auto index = static_cast<unsigned>(it.index()); - Value condOpArg = it.value(); - // Those values not defined within `before` block will be considered as - // loop invariant values. We map the corresponding `index` with their - // value. 
- if (condOpArg.getParentBlock() != &beforeBlock) { - condOpInitValMap.insert({index, condOpArg}); - } else { - newCondOpArgs.emplace_back(condOpArg); - newAfterBlockType.emplace_back(condOpArg.getType()); - newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc()); - } - } - - { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(condOp); - rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(), - newCondOpArgs); - } - - auto newWhile = WhileOp::create(rewriter, op.getLoc(), newAfterBlockType, - op.getOperands()); - - Block &newAfterBlock = - *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {}, - newAfterBlockType, newAfterBlockArgLocs); - - Block &afterBlock = *op.getAfterBody(); - // Since a new scf.condition op was created, we need to fetch the new - // `after` block arguments which will be used while replacing operations of - // previous scf.while's `after` blocks. We'd also be fetching new result - // values too. - SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments()); - SmallVector<Value> newWhileResults(afterBlock.getNumArguments()); - for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) { - Value afterBlockArg, result; - // If index 'i' argument was loop invariant we fetch it's value from the - // `condOpInitMap` map. - if (condOpInitValMap.count(i) != 0) { - afterBlockArg = condOpInitValMap[i]; - result = afterBlockArg; - } else { - afterBlockArg = newAfterBlock.getArgument(j); - result = newWhile.getResult(j); - j++; - } - newAfterBlockArgs[i] = afterBlockArg; - newWhileResults[i] = result; - } - - rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs); - rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(), - newWhile.getBefore().begin()); - - rewriter.replaceOp(op, newWhileResults); - return success(); - } -}; - -/// Remove WhileOp results that are also unused in 'after' block. 
-/// -/// %0:2 = scf.while () : () -> (i32, i64) { -/// %condition = "test.condition"() : () -> i1 -/// %v1 = "test.get_some_value"() : () -> i32 -/// %v2 = "test.get_some_value"() : () -> i64 -/// scf.condition(%condition) %v1, %v2 : i32, i64 -/// } do { -/// ^bb0(%arg0: i32, %arg1: i64): -/// "test.use"(%arg0) : (i32) -> () -/// scf.yield -/// } -/// return %0#0 : i32 -/// -/// becomes -/// %0 = scf.while () : () -> (i32) { -/// %condition = "test.condition"() : () -> i1 -/// %v1 = "test.get_some_value"() : () -> i32 -/// %v2 = "test.get_some_value"() : () -> i64 -/// scf.condition(%condition) %v1 : i32 -/// } do { -/// ^bb0(%arg0: i32): -/// "test.use"(%arg0) : (i32) -> () -/// scf.yield -/// } -/// return %0 : i32 -struct WhileUnusedResult : public OpRewritePattern<WhileOp> { - using OpRewritePattern<WhileOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(WhileOp op, - PatternRewriter &rewriter) const override { - auto term = op.getConditionOp(); - auto afterArgs = op.getAfterArguments(); - auto termArgs = term.getArgs(); - - // Collect results mapping, new terminator args and new result types. 
- SmallVector<unsigned> newResultsIndices; - SmallVector<Type> newResultTypes; - SmallVector<Value> newTermArgs; - SmallVector<Location> newArgLocs; - bool needUpdate = false; - for (const auto &it : - llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) { - auto i = static_cast<unsigned>(it.index()); - Value result = std::get<0>(it.value()); - Value afterArg = std::get<1>(it.value()); - Value termArg = std::get<2>(it.value()); - if (result.use_empty() && afterArg.use_empty()) { - needUpdate = true; - } else { - newResultsIndices.emplace_back(i); - newTermArgs.emplace_back(termArg); - newResultTypes.emplace_back(result.getType()); - newArgLocs.emplace_back(result.getLoc()); - } - } - - if (!needUpdate) - return failure(); - - { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(term); - rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(), - newTermArgs); - } - - auto newWhile = - WhileOp::create(rewriter, op.getLoc(), newResultTypes, op.getInits()); - - Block &newAfterBlock = *rewriter.createBlock( - &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs); - - // Build new results list and new after block args (unused entries will be - // null). - SmallVector<Value> newResults(op.getNumResults()); - SmallVector<Value> newAfterBlockArgs(op.getNumResults()); - for (const auto &it : llvm::enumerate(newResultsIndices)) { - newResults[it.value()] = newWhile.getResult(it.index()); - newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index()); - } - - rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(), - newWhile.getBefore().begin()); - - Block &afterBlock = *op.getAfterBody(); - rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs); - - rewriter.replaceOp(op, newResults); - return success(); - } -}; - /// Replace operations equivalent to the condition in the do block with true, /// since otherwise the block would not be evaluated. 
/// @@ -4407,65 +3987,6 @@ struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> { } }; -/// Remove unused init/yield args. -struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> { - using OpRewritePattern<WhileOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(WhileOp op, - PatternRewriter &rewriter) const override { - - if (!llvm::any_of(op.getBeforeArguments(), - [](Value arg) { return arg.use_empty(); })) - return rewriter.notifyMatchFailure(op, "No args to remove"); - - YieldOp yield = op.getYieldOp(); - - // Collect results mapping, new terminator args and new result types. - SmallVector<Value> newYields; - SmallVector<Value> newInits; - llvm::BitVector argsToErase; - - size_t argsCount = op.getBeforeArguments().size(); - newYields.reserve(argsCount); - newInits.reserve(argsCount); - argsToErase.reserve(argsCount); - for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip( - op.getBeforeArguments(), yield.getOperands(), op.getInits())) { - if (beforeArg.use_empty()) { - argsToErase.push_back(true); - } else { - argsToErase.push_back(false); - newYields.emplace_back(yieldValue); - newInits.emplace_back(initValue); - } - } - - Block &beforeBlock = *op.getBeforeBody(); - Block &afterBlock = *op.getAfterBody(); - - beforeBlock.eraseArguments(argsToErase); - - Location loc = op.getLoc(); - auto newWhileOp = - WhileOp::create(rewriter, loc, op.getResultTypes(), newInits, - /*beforeBody*/ nullptr, /*afterBody*/ nullptr); - Block &newBeforeBlock = *newWhileOp.getBeforeBody(); - Block &newAfterBlock = *newWhileOp.getAfterBody(); - - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(yield); - rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields); - - rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, - newBeforeBlock.getArguments()); - rewriter.mergeBlocks(&afterBlock, &newAfterBlock, - newAfterBlock.getArguments()); - - rewriter.replaceOp(op, newWhileOp.getResults()); - return success(); - } -}; - /// Remove duplicated 
ConditionOp args. struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> { using OpRewritePattern::OpRewritePattern; @@ -4618,11 +4139,8 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> { void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<RemoveLoopInvariantArgsFromBeforeBlock, - RemoveLoopInvariantValueYielded, WhileConditionTruth, - WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults, - WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>( - context); + results.add<WhileConditionTruth, WhileCmpCond, WhileRemoveDuplicatedResults, + WhileOpAlignBeforeArgs, WhileMoveIfDown>(context); } //===----------------------------------------------------------------------===// @@ -4797,59 +4315,9 @@ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> { } }; -/// Canonicalization patterns that folds away dead results of -/// "scf.index_switch" ops. -struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> { - using OpRewritePattern<IndexSwitchOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(IndexSwitchOp op, - PatternRewriter &rewriter) const override { - // Find dead results. - BitVector deadResults(op.getNumResults(), false); - SmallVector<Type> newResultTypes; - for (auto [idx, result] : llvm::enumerate(op.getResults())) { - if (!result.use_empty()) { - newResultTypes.push_back(result.getType()); - } else { - deadResults[idx] = true; - } - } - if (!deadResults.any()) - return rewriter.notifyMatchFailure(op, "no dead results to fold"); - - // Create new op without dead results and inline case regions. - auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes, - op.getArg(), op.getCases(), - op.getCaseRegions().size()); - auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) { - rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin()); - // Remove respective operands from yield op. 
- Operation *terminator = newRegion.front().getTerminator(); - assert(isa<YieldOp>(terminator) && "expected yield op"); - rewriter.modifyOpInPlace( - terminator, [&]() { terminator->eraseOperands(deadResults); }); - }; - for (auto [oldRegion, newRegion] : - llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions())) - inlineCaseRegion(oldRegion, newRegion); - inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion()); - - // Replace op with new op. - SmallVector<Value> newResults(op.getNumResults(), Value()); - unsigned nextNewResult = 0; - for (unsigned idx = 0; idx < op.getNumResults(); ++idx) { - if (deadResults[idx]) - continue; - newResults[idx] = newOp.getResult(nextNewResult++); - } - rewriter.replaceOp(op, newResults); - return success(); - } -}; - void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context); + results.add<FoldConstantCase>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp index d393ddb8d8336..ed94205d32f19 100644 --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -521,6 +521,45 @@ void RegionBranchOpInterface::getSuccessorOperandInputMapping( } } +void RegionBranchOpInterface::getSuccessorInputOperandMapping( + RegionBranchInverseSuccessorMapping &mapping) { + RegionBranchSuccessorMapping operandToInputs; + getSuccessorOperandInputMapping(operandToInputs); + for (const auto &[operand, inputs] : operandToInputs) { + for (Value input : inputs) + mapping[input].push_back(operand); + } +} + +DenseSet<Value> +RegionBranchOpInterface::computePossibleValuesOfSuccessorInput(Value value) { + RegionBranchInverseSuccessorMapping inputToOperands; + getSuccessorInputOperandMapping(inputToOperands); + + DenseSet<Value> possibleValues; 
+ DenseSet<Value> visited; + SmallVector<Value> worklist; + + // Starting with the given value, trace back all predecessor values (i.e., + // preceding successor operands) and add them to the set of possible values. + // If the successor operand is again a successor input, do not add it to + // the result set, but instead continue the traversal. + worklist.push_back(value); + while (!worklist.empty()) { + Value next = worklist.pop_back_val(); + auto it = inputToOperands.find(next); + if (it == inputToOperands.end()) { + possibleValues.insert(next); + continue; + } + for (OpOperand *operand : it->second) + if (visited.insert(operand->get()).second) + worklist.push_back(operand->get()); + } + + return possibleValues; +} + SmallVector<RegionBranchPoint> RegionBranchOpInterface::getAllRegionBranchPoints() { SmallVector<RegionBranchPoint> branchPoints; diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 984ea10f7e540..0420d1c018d76 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -1736,11 +1736,11 @@ module { // Test case with multiple scf.yield ops with at least one different operand, then no change. 
-// CHECK: %[[VAL_3:.*]]:2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline { +// CHECK: %[[VAL_3:.*]] = scf.execute_region -> memref<1x120xui8> no_inline { // CHECK: ^bb1: -// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8> +// CHECK: scf.yield %{{.*}} : memref<1x120xui8> // CHECK: ^bb2: -// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8> +// CHECK: scf.yield %{{.*}} : memref<1x120xui8> // CHECK: } module { @@ -2214,16 +2214,14 @@ func.func @iter_args_cycles_non_cycle_start(%lb : index, %ub : index, %step : in // CHECK-SAME: %[[arg0:.*]]: index // CHECK-DAG: %[[c10:.*]] = arith.constant 10 // CHECK-DAG: %[[c11:.*]] = arith.constant 11 -// CHECK: %[[switch:.*]] = scf.index_switch %[[arg0]] -> index +// CHECK: scf.index_switch %[[arg0]] // CHECK: case 1 { // CHECK: memref.store %[[c10]] -// CHECK: scf.yield %[[arg0]] : index // CHECK: } // CHECK: default { // CHECK: memref.store %[[c11]] -// CHECK: scf.yield %[[arg0]] : index // CHECK: } -// CHECK: return %[[switch]] +// CHECK: return %[[arg0]] func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> index { %non_live, %live = scf.index_switch %arg0 -> i32, index case 1 { >From 447cadee988ede17eda4df12bc316c8b7e688807 Mon Sep 17 00:00:00 2001 From: Matthias Springer <[email protected]> Date: Wed, 31 Dec 2025 16:00:05 +0000 Subject: [PATCH 2/2] fix some tests --- .../mlir/Dialect/SparseTensor/IR/SparseTensorOps.td | 7 ++++--- mlir/test/Dialect/Vector/vector-warp-distribute.mlir | 6 +++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index a61d90a0c39b1..f41b3694d9c79 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -1304,9 +1304,10 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, 
SameOperandsAndResu let hasVerifier = 1; } -def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator, - ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp", - "ForeachOp", "IterateOp", "CoIterateOp"]>]> { +def SparseTensor_YieldOp : SparseTensor_Op<"yield", + [Pure, Terminator, ReturnLike, + ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp", + "ForeachOp", "IterateOp", "CoIterateOp"]>]> { let summary = "Yield from sparse_tensor set-like operations"; let description = [{ Yields a value from within a `binary`, `unary`, `reduce`, diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index 135db02d543ef..18fb6852f6875 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -1330,11 +1330,11 @@ func.func @vector_insert_1d_broadcast(%laneid: index, %pos: index) -> (vector<96 // ----- // CHECK-PROP-LABEL: func @vector_insert_0d( -// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0{{.*}} -> (vector<f32>, f32) +// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (f32) // CHECK-PROP: %[[VEC:.*]] = "some_def" // CHECK-PROP: %[[VAL:.*]] = "another_def" -// CHECK-PROP: gpu.yield %[[VEC]], %[[VAL]] -// CHECK-PROP: vector.broadcast %[[W]]#1 : f32 to vector<f32> +// CHECK-PROP: gpu.yield %[[VAL]] +// CHECK-PROP: vector.broadcast %[[W]] : f32 to vector<f32> func.func @vector_insert_0d(%laneid: index) -> (vector<f32>) { %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<f32>) { %0 = "some_def"() : () -> (vector<f32>) _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
