https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/174094
>From 8b2bcb7b6652272bc614d37ca2305000797b2b3f Mon Sep 17 00:00:00 2001 From: Matthias Springer <[email protected]> Date: Wed, 31 Dec 2025 14:07:51 +0000 Subject: [PATCH] [mlir][draft] Consolidate patterns into RegionBranchOpInterface patterns fix some tests --- .../SparseTensor/IR/SparseTensorOps.td | 7 +- .../mlir/Interfaces/ControlFlowInterfaces.h | 2 + .../mlir/Interfaces/ControlFlowInterfaces.td | 9 + mlir/lib/Dialect/SCF/IR/SCF.cpp | 1134 ++++------------- mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 39 + mlir/test/Dialect/SCF/canonicalize.mlir | 24 +- .../Dialect/SparseTensor/sparse_kernels.mlir | 16 +- .../test/Dialect/SparseTensor/sparse_out.mlir | 34 +- .../Vector/vector-warp-distribute.mlir | 6 +- mlir/test/Transforms/remove-dead-values.mlir | 8 +- 10 files changed, 377 insertions(+), 902 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index a61d90a0c39b1..f41b3694d9c79 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -1304,9 +1304,10 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu let hasVerifier = 1; } -def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator, - ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp", - "ForeachOp", "IterateOp", "CoIterateOp"]>]> { +def SparseTensor_YieldOp : SparseTensor_Op<"yield", + [Pure, Terminator, ReturnLike, + ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp", + "ForeachOp", "IterateOp", "CoIterateOp"]>]> { let summary = "Yield from sparse_tensor set-like operations"; let description = [{ Yields a value from within a `binary`, `unary`, `reduce`, diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h index 566f4b8fadb5d..a7565f9f7bb78 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h @@ -188,6 +188,8 @@ LogicalResult verifyTypesAlongControlFlowEdges(Operation *op); /// possible successors.) Operands that not forwarded at all are not present in /// the mapping. using RegionBranchSuccessorMapping = DenseMap<OpOperand *, SmallVector<Value>>; +using RegionBranchInverseSuccessorMapping = + DenseMap<Value, SmallVector<OpOperand *>>; /// This class represents a successor of a region. A region successor can either /// be another region, or the parent operation. If the successor is a region, diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td index 2e654ba04ffe5..9366e5562b774 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -355,6 +355,15 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> { ::mlir::RegionBranchSuccessorMapping &mapping, std::optional<::mlir::RegionBranchPoint> src = std::nullopt); + /// Build a mapping from successor inputs to successor operands. This is + /// the same as "getSuccessorOperandInputMapping", but inverted. + void getSuccessorInputOperandMapping( + ::mlir::RegionBranchInverseSuccessorMapping &mapping); + + /// Compute all values that a successor input could possibly have. If the + /// given value is not a successor input, return an empty set. 
+ ::llvm::DenseSet<Value> computePossibleValuesOfSuccessorInput(::mlir::Value value); + /// Return all possible region branch points: the region branch op itself /// and all region branch terminators. ::llvm::SmallVector<::mlir::RegionBranchPoint> getAllRegionBranchPoints(); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 0a123112cf68f..06b542d1c1dae 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -27,6 +27,7 @@ #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Transforms/InliningUtils.h" #include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" @@ -291,102 +292,9 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { } }; -// Pattern to eliminate ExecuteRegionOp results which forward external -// values from the region. In case there are multiple yield operations, -// all of them must have the same operands in order for the pattern to be -// applicable. -struct ExecuteRegionForwardingEliminator - : public OpRewritePattern<ExecuteRegionOp> { - using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(ExecuteRegionOp op, - PatternRewriter &rewriter) const override { - if (op.getNumResults() == 0) - return failure(); - - SmallVector<Operation *> yieldOps; - for (Block &block : op.getRegion()) { - if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator())) - yieldOps.push_back(yield.getOperation()); - } - - if (yieldOps.empty()) - return failure(); - - // Check if all yield operations have the same operands. - auto yieldOpsOperands = yieldOps[0]->getOperands(); - for (auto *yieldOp : yieldOps) { - if (yieldOp->getOperands() != yieldOpsOperands) - return failure(); - } - - SmallVector<Value> externalValues; - SmallVector<Value> internalValues; - SmallVector<Value> opResultsToReplaceWithExternalValues; - SmallVector<Value> opResultsToKeep; - for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) { - if (isValueFromInsideRegion(yieldedValue, op)) { - internalValues.push_back(yieldedValue); - opResultsToKeep.push_back(op.getResult(index)); - } else { - externalValues.push_back(yieldedValue); - opResultsToReplaceWithExternalValues.push_back(op.getResult(index)); - } - } - // No yielded external values - nothing to do. - if (externalValues.empty()) - return failure(); - - // There are yielded external values - create a new execute_region returning - // just the internal values. - SmallVector<Type> resultTypes; - for (Value value : internalValues) - resultTypes.push_back(value.getType()); - auto newOp = - ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes)); - newOp->setAttrs(op->getAttrs()); - - // Move old op's region to the new operation. - rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), - newOp.getRegion().end()); - - // Replace all yield operations with a new yield operation with updated - // results. scf.execute_region must have at least one yield operation. - for (auto *yieldOp : yieldOps) { - rewriter.setInsertionPoint(yieldOp); - rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, - ValueRange(internalValues)); - } - - // Replace the old operation with the external values directly. - rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues, - externalValues); - // Replace the old operation's remaining results with the new operation's - // results. 
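To make the terminology used by the new interface methods concrete, consider a small scf.for (an illustrative sketch, not part of the patch; all names are hypothetical): the region iter argument %acc and the op result %sum are successor inputs, while the init operand %init and the yielded %next are the successor operands that feed them. For this IR, computePossibleValuesOfSuccessorInput(%acc) would be expected to return {%init, %next}.

func.func @possible_values_sketch(%init: i32, %ub: index) -> i32 {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // %acc and %sum are successor inputs; %init and %next are the
  // successor operands forwarded to them.
  %sum = scf.for %iv = %c0 to %ub step %c1 iter_args(%acc = %init) -> (i32) {
    %next = arith.addi %acc, %acc : i32
    scf.yield %next : i32
  }
  return %sum : i32
}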
- rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults()); - rewriter.eraseOp(op); - return success(); - } - -private: - bool isValueFromInsideRegion(Value value, - ExecuteRegionOp executeRegionOp) const { - // Check if the value is defined within the execute_region - if (Operation *defOp = value.getDefiningOp()) - return &executeRegionOp.getRegion() == defOp->getParentRegion(); - - // If it's a block argument, check if it's from within the region - if (BlockArgument blockArg = dyn_cast<BlockArgument>(value)) - return &executeRegionOp.getRegion() == blockArg.getParentRegion(); - - return false; // Value is from outside the region - } -}; - void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner, - ExecuteRegionForwardingEliminator>(context); + results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context); } void ExecuteRegionOp::getSuccessorRegions( @@ -989,146 +897,6 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, } namespace { -// Fold away ForOp iter arguments when: -// 1) The op yields the iter arguments. -// 2) The argument's corresponding outer region iterators (inputs) are yielded. -// 3) The iter arguments have no use and the corresponding (operation) results -// have no use. -// -// These arguments must be defined outside of the ForOp region and can just be -// forwarded after simplifying the op inits, yields and returns. -// -// The implementation uses `inlineBlockBefore` to steal the content of the -// original ForOp and avoid cloning. -struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> { - using OpRewritePattern<scf::ForOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(scf::ForOp forOp, - PatternRewriter &rewriter) const final { - bool canonicalize = false; - - // An internal flat vector of block transfer - // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to - // transformed block argument mappings. This plays the role of a - // IRMapping for the particular use case of calling into - // `inlineBlockBefore`. - int64_t numResults = forOp.getNumResults(); - SmallVector<bool, 4> keepMask; - keepMask.reserve(numResults); - SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues, - newResultValues; - newBlockTransferArgs.reserve(1 + numResults); - newBlockTransferArgs.push_back(Value()); // iv placeholder with null value - newIterArgs.reserve(forOp.getInitArgs().size()); - newYieldValues.reserve(numResults); - newResultValues.reserve(numResults); - DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg; - for (auto [init, arg, result, yielded] : - llvm::zip(forOp.getInitArgs(), // iter from outside - forOp.getRegionIterArgs(), // iter inside region - forOp.getResults(), // op results - forOp.getYieldedValues() // iter yield - )) { - // Forwarded is `true` when: - // 1) The region `iter` argument is yielded. - // 2) The region `iter` argument the corresponding input is yielded. - // 3) The region `iter` argument has no use, and the corresponding op - // result has no use. - bool forwarded = (arg == yielded) || (init == yielded) || - (arg.use_empty() && result.use_empty()); - if (forwarded) { - canonicalize = true; - keepMask.push_back(false); - newBlockTransferArgs.push_back(init); - newResultValues.push_back(init); - continue; - } - - // Check if a previous kept argument always has the same values for init - // and yielded values. 
- if (auto it = initYieldToArg.find({init, yielded}); - it != initYieldToArg.end()) { - canonicalize = true; - keepMask.push_back(false); - auto [sameArg, sameResult] = it->second; - rewriter.replaceAllUsesWith(arg, sameArg); - rewriter.replaceAllUsesWith(result, sameResult); - // The replacement value doesn't matter because there are no uses. - newBlockTransferArgs.push_back(init); - newResultValues.push_back(init); - continue; - } - - // This value is kept. - initYieldToArg.insert({{init, yielded}, {arg, result}}); - keepMask.push_back(true); - newIterArgs.push_back(init); - newYieldValues.push_back(yielded); - newBlockTransferArgs.push_back(Value()); // placeholder with null value - newResultValues.push_back(Value()); // placeholder with null value - } - - if (!canonicalize) - return failure(); - - scf::ForOp newForOp = - scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), - forOp.getUpperBound(), forOp.getStep(), newIterArgs, - /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp()); - newForOp->setAttrs(forOp->getAttrs()); - Block &newBlock = newForOp.getRegion().front(); - - // Replace the null placeholders with newly constructed values. - newBlockTransferArgs[0] = newBlock.getArgument(0); // iv - for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size(); - idx != e; ++idx) { - Value &blockTransferArg = newBlockTransferArgs[1 + idx]; - Value &newResultVal = newResultValues[idx]; - assert((blockTransferArg && newResultVal) || - (!blockTransferArg && !newResultVal)); - if (!blockTransferArg) { - blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx]; - newResultVal = newForOp.getResult(collapsedIdx++); - } - } - - Block &oldBlock = forOp.getRegion().front(); - assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() && - "unexpected argument size mismatch"); - - // No results case: the scf::ForOp builder already created a zero - // result terminator. Merge before this terminator and just get rid of the - // original terminator that has been merged in. - if (newIterArgs.empty()) { - auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator()); - rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs); - rewriter.eraseOp(newBlock.getTerminator()->getPrevNode()); - rewriter.replaceOp(forOp, newResultValues); - return success(); - } - - // No terminator case: merge and rewrite the merged terminator. - auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(mergedTerminator); - SmallVector<Value, 4> filteredOperands; - filteredOperands.reserve(newResultValues.size()); - for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx) - if (keepMask[idx]) - filteredOperands.push_back(mergedTerminator.getOperand(idx)); - scf::YieldOp::create(rewriter, mergedTerminator.getLoc(), - filteredOperands); - }; - - rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs); - auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator()); - cloneFilteredTerminator(mergedYieldOp); - rewriter.eraseOp(mergedYieldOp); - rewriter.replaceOp(forOp, newResultValues); - return success(); - } -}; - /// Rewriting pattern that erases loops that are known not to iterate, replaces /// single-iteration loops with their bodies, and removes empty loops that /// iterate at least once and only return values defined outside of the loop. @@ -1236,12 +1004,283 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> { } }; +/// Is a defined before b? 
+static bool isDefinedBefore(Value a, Value b) { + Region *aRegion = a.getParentRegion(); + Region *bRegion = b.getParentRegion(); + + if (aRegion->getParentOp()->isProperAncestor(bRegion->getParentOp())) { + return true; + } + if (aRegion == bRegion) { + Block *aBlock = a.getParentBlock(); + Block *bBlock = b.getParentBlock(); + if (aBlock != bBlock) + return false; + if (isa<BlockArgument>(a)) + return true; + if (isa<BlockArgument>(b)) + return false; + return a.getDefiningOp()->isBeforeInBlock(b.getDefiningOp()); + } + + return false; +} + +// Try to make successor inputs dead by replacing their uses with values that +// are not successor inputs. This pattern enables additional canonicalization +// opportunities for RemoveDeadValues. +struct RemoveUsesOfIdenticalValues + : public OpInterfaceRewritePattern<RegionBranchOpInterface> { + using OpInterfaceRewritePattern< + RegionBranchOpInterface>::OpInterfaceRewritePattern; + + LogicalResult matchAndRewrite(RegionBranchOpInterface op, + PatternRewriter &rewriter) const override { + // TODO: ForallOp data flow is modeled incompletely. + if (isa<ForallOp>(op)) + return failure(); + + // Gather all potential successor inputs. (Other values may also be + // included, but we're not doing anything with them.) + SmallVector<Value> values; + llvm::append_range(values, op->getResults()); + for (Region &r : op->getRegions()) + llvm::append_range(values, r.getArguments()); + + bool changed = false; + for (Value value : values) { + if (value.use_empty()) + continue; + DenseSet<Value> possibleValues = + op.computePossibleValuesOfSuccessorInput(value); + if (possibleValues.size() == 1 && *possibleValues.begin() != value && + isDefinedBefore(*possibleValues.begin(), value)) { + // Value is same as another value. + rewriter.replaceAllUsesWith(value, *possibleValues.begin()); + changed = true; + } + } + return success(changed); + } +}; + +/// Pattern to remove dead values from region branch ops. +struct RemoveDeadValues + : public OpInterfaceRewritePattern<RegionBranchOpInterface> { + using OpInterfaceRewritePattern< + RegionBranchOpInterface>::OpInterfaceRewritePattern; + + LogicalResult matchAndRewrite(RegionBranchOpInterface op, + PatternRewriter &rewriter) const override { + // TODO: ForallOp data flow is modeled incompletely. + if (isa<ForallOp>(op)) + return failure(); + + // Compute tied values: values that must come as a set. If you remove one, + // you must remove all. + RegionBranchSuccessorMapping operandToInputs; + op.getSuccessorOperandInputMapping(operandToInputs); + llvm::EquivalenceClasses<Value> tiedSuccessorInputs; + for (const auto &[operand, inputs] : operandToInputs) { + assert(!inputs.empty() && "expected non-empty inputs"); + Value firstInput = inputs.front(); + tiedSuccessorInputs.insert(firstInput); + for (Value nextInput : llvm::drop_begin(inputs)) + tiedSuccessorInputs.unionSets(firstInput, nextInput); + } + + // Determine which values to remove and group them by block and operation. + SmallVector<Value> valuesToRemove; + DenseMap<Block *, BitVector> blockArgsToRemove; + DenseMap<Operation *, BitVector> resultsToRemove; + for (auto it = tiedSuccessorInputs.begin(), e = tiedSuccessorInputs.end(); + it != e; ++it) { + if (!(*it)->isLeader()) + continue; + + // Value can be removed if it is dead and all other tied values are also + // dead. 
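As a concrete illustration of RemoveUsesOfIdenticalValues (hypothetical IR, not taken from the patch; %lb, %ub, %c1 and %init are assumed to be defined earlier): when an scf.for yields its iter argument back unchanged, the only possible value of that argument is the init, so all of its uses can be rewired to the init value.

// Before: %acc is always equal to %init.
%r = scf.for %iv = %lb to %ub step %c1 iter_args(%acc = %init) -> (i32) {
  "test.use"(%acc) : (i32) -> ()
  scf.yield %acc : i32
}
"test.use"(%r) : (i32) -> ()
// After all uses of %acc and %r have been rewired to %init, the tied pair
// (%acc, %r) is dead and RemoveDeadValues can erase the iter_args entirely:
scf.for %iv = %lb to %ub step %c1 {
  "test.use"(%init) : (i32) -> ()
}
"test.use"(%init) : (i32) -> ()

Note that %acc and %r are tied in the sense used by RemoveDeadValues above: they are fed by the same init/yield operands, so they can only be erased together.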
+ bool allDead = true; + for (auto memberIt = tiedSuccessorInputs.member_begin(**it); + memberIt != tiedSuccessorInputs.member_end(); ++memberIt) { + if (!memberIt->use_empty()) { + allDead = false; + break; + } + } + if (!allDead) + continue; + + // Group values by block and operation. + for (auto memberIt = tiedSuccessorInputs.member_begin(**it); + memberIt != tiedSuccessorInputs.member_end(); ++memberIt) { + if (auto arg = dyn_cast<BlockArgument>(*memberIt)) { + BitVector &vector = + blockArgsToRemove + .try_emplace(arg.getOwner(), + arg.getOwner()->getNumArguments(), false) + .first->second; + vector.set(arg.getArgNumber()); + } else { + OpResult result = cast<OpResult>(*memberIt); + BitVector &vector = + resultsToRemove + .try_emplace(result.getDefiningOp(), + result.getDefiningOp()->getNumResults(), false) + .first->second; + vector.set(result.getResultNumber()); + } + valuesToRemove.push_back(*memberIt); + } + } + + if (valuesToRemove.empty()) + return rewriter.notifyMatchFailure(op, "no values to remove"); + + // Find operands that must be removed together with the values. + RegionBranchInverseSuccessorMapping inputsToOperands; + op.getSuccessorInputOperandMapping(inputsToOperands); + DenseMap<Operation *, llvm::BitVector> operandsToRemove; + for (Value value : valuesToRemove) { + for (OpOperand *operand : inputsToOperands[value]) { + BitVector &vector = + operandsToRemove + .try_emplace(operand->getOwner(), + operand->getOwner()->getNumOperands(), false) + .first->second; + vector.set(operand->getOperandNumber()); + } + } + + // Erase operands. + for (auto [op, operands] : operandsToRemove) { + rewriter.modifyOpInPlace(op, [&]() { op->eraseOperands(operands); }); + } + + // Erase block arguments. + for (auto [block, blockArgs] : blockArgsToRemove) { + rewriter.modifyOpInPlace(block->getParentOp(), + [&]() { block->eraseArguments(blockArgs); }); + } + + // Erase op results. + // TODO: Can we move this to RewriterBase, so we have a uniform API, + // similar to eraseArguments? + for (auto [op, resultsToErase] : resultsToRemove) { + rewriter.setInsertionPoint(op); + SmallVector<Type> newResultTypes; + for (OpResult result : op->getResults()) + if (!resultsToErase[result.getResultNumber()]) + newResultTypes.push_back(result.getType()); + OperationState state(op->getLoc(), op->getName().getStringRef(), + op->getOperands(), newResultTypes, op->getAttrs()); + for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) + state.addRegion(); + Operation *newOp = rewriter.create(state); + for (const auto &[index, region] : llvm::enumerate(op->getRegions())) { + // Move all blocks of `region` into `newRegion`. 
+ Region &newRegion = newOp->getRegion(index); + rewriter.inlineRegionBefore(region, newRegion, newRegion.begin()); + } + + SmallVector<Value> newResults; + unsigned nextLiveResult = 0; + for (auto [index, result] : llvm::enumerate(op->getResults())) { + if (!resultsToErase[index]) { + newResults.push_back(newOp->getResult(nextLiveResult++)); + } else { + newResults.push_back(Value()); + } + } + rewriter.replaceOp(op, newResults); + } + + return success(); + } +}; + +void *getContainerOwnerOfValue(Value value) { + if (auto opResult = llvm::dyn_cast<OpResult>(value)) + return opResult.getDefiningOp(); + return llvm::cast<BlockArgument>(value).getOwner(); +} + +unsigned getArgOrResultNumber(Value value) { + if (auto opResult = llvm::dyn_cast<OpResult>(value)) + return opResult.getResultNumber(); + return llvm::cast<BlockArgument>(value).getArgNumber(); +} + +/// Pattern to make duplicate successor inputs dead. Two successor inputs are +/// duplicate if their corresponding successor operands have the same values. +struct RemoveDuplicateSuccessorInputUses + : public OpInterfaceRewritePattern<RegionBranchOpInterface> { + using OpInterfaceRewritePattern< + RegionBranchOpInterface>::OpInterfaceRewritePattern; + + LogicalResult matchAndRewrite(RegionBranchOpInterface op, + PatternRewriter &rewriter) const override { + // TODO: ForallOp data flow is modeled incompletely. + if (isa<ForallOp>(op)) + return failure(); + + // Collect all successor inputs and sort them. When dropping the uses of a + // successor input, we'd like to also drop the uses of the same tied + // successor inputs. Otherwise, a set of tied successor inputs may not + // become entirely dead, which is required for RemoveDeadValues to erase + // them. (Sorting is not required for correctness.) + RegionBranchInverseSuccessorMapping inputsToOperands; + op.getSuccessorInputOperandMapping(inputsToOperands); + SmallVector<Value> inputs = llvm::to_vector(inputsToOperands.keys()); + llvm::sort(inputs, [](Value a, Value b) { + return getArgOrResultNumber(a) < getArgOrResultNumber(b); + }); + + bool changed = false; + for (unsigned i = 0, e = inputs.size(); i < e; i++) { + Value input1 = inputs[i]; + for (unsigned j = i + 1; j < e; j++) { + Value input2 = inputs[j]; + // Nothing to do if input2 is already dead. + if (input2.use_empty()) + continue; + // Replace only values of the same kind. + if (isa<BlockArgument>(input1) != isa<BlockArgument>(input2)) + continue; + // Replace only values that belong to the same block / operation. + if (getContainerOwnerOfValue(input1) != + getContainerOwnerOfValue(input2)) + continue; + + // Gather the predecessor value for each predecessor (region branch + // point). The two inputs are duplicates if each predecessor forwards + // the same value. 
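For illustration of the duplicate-input criterion just described (hypothetical IR; %lb, %ub, %c1 and %init are assumed to be defined earlier): two iter arguments that receive the same value from every region branch point, i.e. the same init and the same yielded value, are duplicates, so uses of the second one can be rewired to the first.

// Before: %a and %b are fed the same value by every predecessor.
%r:2 = scf.for %iv = %lb to %ub step %c1 iter_args(%a = %init, %b = %init) -> (i32, i32) {
  %v = arith.addi %a, %b : i32
  scf.yield %v, %v : i32, i32
}
// After: uses of %b (and of %r#1) are replaced with %a (and %r#0), making the
// second iter argument dead so that RemoveDeadValues can erase it.
%r:2 = scf.for %iv = %lb to %ub step %c1 iter_args(%a = %init, %b = %init) -> (i32, i32) {
  %v = arith.addi %a, %a : i32
  scf.yield %v, %v : i32, i32
}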
+ DenseMap<Operation *, Value> operands1, operands2; + for (OpOperand *operand : inputsToOperands[input1]) { + assert(!operands1.contains(operand->getOwner())); + operands1[operand->getOwner()] = operand->get(); + } + for (OpOperand *operand : inputsToOperands[input2]) { + assert(!operands2.contains(operand->getOwner())); + operands2[operand->getOwner()] = operand->get(); + } + if (operands1 == operands2) { + rewriter.replaceAllUsesWith(input2, input1); + changed = true; + } + } + } + return success(changed); + } +}; } // namespace void ForOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>( - context); + results.add<SimplifyTrivialLoops, ForOpTensorCastFolder, + RemoveUsesOfIdenticalValues, RemoveDeadValues, + RemoveDuplicateSuccessorInputUses>(context); } std::optional<APInt> ForOp::getConstantStep() { @@ -2409,61 +2448,6 @@ void IfOp::getRegionInvocationBounds( } namespace { -// Pattern to remove unused IfOp results. -struct RemoveUnusedResults : public OpRewritePattern<IfOp> { - using OpRewritePattern<IfOp>::OpRewritePattern; - - void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults, - PatternRewriter &rewriter) const { - // Move all operations to the destination block. - rewriter.mergeBlocks(source, dest); - // Replace the yield op by one that returns only the used values. - auto yieldOp = cast<scf::YieldOp>(dest->getTerminator()); - SmallVector<Value, 4> usedOperands; - llvm::transform(usedResults, std::back_inserter(usedOperands), - [&](OpResult result) { - return yieldOp.getOperand(result.getResultNumber()); - }); - rewriter.modifyOpInPlace(yieldOp, - [&]() { yieldOp->setOperands(usedOperands); }); - } - - LogicalResult matchAndRewrite(IfOp op, - PatternRewriter &rewriter) const override { - // Compute the list of used results. - SmallVector<OpResult, 4> usedResults; - llvm::copy_if(op.getResults(), std::back_inserter(usedResults), - [](OpResult result) { return !result.use_empty(); }); - - // Replace the operation if only a subset of its results have uses. - if (usedResults.size() == op.getNumResults()) - return failure(); - - // Compute the result types of the replacement operation. - SmallVector<Type, 4> newTypes; - llvm::transform(usedResults, std::back_inserter(newTypes), - [](OpResult result) { return result.getType(); }); - - // Create a replacement operation with empty then and else regions. - auto newOp = - IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition()); - rewriter.createBlock(&newOp.getThenRegion()); - rewriter.createBlock(&newOp.getElseRegion()); - - // Move the bodies and replace the terminators (note there is a then and - // an else region since the operation returns results). - transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter); - transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter); - - // Replace the operation by the new one. 
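The dedicated IfOp pattern removed here is subsumed by the generic, interface-based RemoveDeadValues pattern added above. Roughly, the rewrite it performed looks as follows (hypothetical IR; %cond, %a, %b, %c and %d are assumed to be defined earlier):

// Before: %res#1 has no uses.
%res:2 = scf.if %cond -> (i32, i32) {
  scf.yield %a, %b : i32, i32
} else {
  scf.yield %c, %d : i32, i32
}
"test.use"(%res#0) : (i32) -> ()
// After: the dead result and the corresponding yield operands are erased.
%res = scf.if %cond -> (i32) {
  scf.yield %a : i32
} else {
  scf.yield %c : i32
}
"test.use"(%res) : (i32) -> ()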
- SmallVector<Value, 4> repResults(op.getNumResults()); - for (const auto &en : llvm::enumerate(usedResults)) - repResults[en.value().getResultNumber()] = newOp.getResult(en.index()); - rewriter.replaceOp(op, repResults); - return success(); - } -}; - struct RemoveStaticCondition : public OpRewritePattern<IfOp> { using OpRewritePattern<IfOp>::OpRewritePattern; @@ -3034,8 +3018,8 @@ void IfOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add<CombineIfs, CombineNestedIfs, ConditionPropagation, ConvertTrivialIfToSelect, RemoveEmptyElseBranch, - RemoveStaticCondition, RemoveUnusedResults, - ReplaceIfYieldWithConditionOrValue>(context); + RemoveStaticCondition, ReplaceIfYieldWithConditionOrValue>( + context); } Block *IfOp::thenBlock() { return &getThenRegion().back(); } @@ -3873,390 +3857,6 @@ struct WhileConditionTruth : public OpRewritePattern<WhileOp> { } }; -/// Remove loop invariant arguments from `before` block of scf.while. -/// A before block argument is considered loop invariant if :- -/// 1. i-th yield operand is equal to the i-th while operand. -/// 2. i-th yield operand is k-th after block argument which is (k+1)-th -/// condition operand AND this (k+1)-th condition operand is equal to i-th -/// iter argument/while operand. -/// For the arguments which are removed, their uses inside scf.while -/// are replaced with their corresponding initial value. -/// -/// Eg: -/// INPUT :- -/// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b, -/// ..., %argN_before = %N) -/// { -/// ... -/// scf.condition(%cond) %arg1_before, %arg0_before, -/// %arg2_before, %arg0_before, ... -/// } do { -/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2, -/// ..., %argK_after): -/// ... -/// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN -/// } -/// -/// OUTPUT :- -/// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before = -/// %N) -/// { -/// ... -/// scf.condition(%cond) %b, %a, %arg2_before, %a, ... -/// } do { -/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2, -/// ..., %argK_after): -/// ... -/// scf.yield %arg1_after, ..., %argN -/// } -/// -/// EXPLANATION: -/// We iterate over each yield operand. -/// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand -/// %arg0_before, which in turn is the 0-th iter argument. So we -/// remove 0-th before block argument and yield operand, and replace -/// all uses of the 0-th before block argument with its initial value -/// %a. -/// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial -/// value. So we remove this operand and the corresponding before -/// block argument and replace all uses of 1-th before block argument -/// with %b. 
-struct RemoveLoopInvariantArgsFromBeforeBlock - : public OpRewritePattern<WhileOp> { - using OpRewritePattern<WhileOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(WhileOp op, - PatternRewriter &rewriter) const override { - Block &afterBlock = *op.getAfterBody(); - Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments(); - ConditionOp condOp = op.getConditionOp(); - OperandRange condOpArgs = condOp.getArgs(); - Operation *yieldOp = afterBlock.getTerminator(); - ValueRange yieldOpArgs = yieldOp->getOperands(); - - bool canSimplify = false; - for (const auto &it : - llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) { - auto index = static_cast<unsigned>(it.index()); - auto [initVal, yieldOpArg] = it.value(); - // If i-th yield operand is equal to the i-th operand of the scf.while, - // the i-th before block argument is a loop invariant. - if (yieldOpArg == initVal) { - canSimplify = true; - break; - } - // If the i-th yield operand is k-th after block argument, then we check - // if the (k+1)-th condition op operand is equal to either the i-th before - // block argument or the initial value of i-th before block argument. If - // the comparison results `true`, i-th before block argument is a loop - // invariant. - auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg); - if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) { - Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()]; - if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) { - canSimplify = true; - break; - } - } - } - - if (!canSimplify) - return failure(); - - SmallVector<Value> newInitArgs, newYieldOpArgs; - DenseMap<unsigned, Value> beforeBlockInitValMap; - SmallVector<Location> newBeforeBlockArgLocs; - for (const auto &it : - llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) { - auto index = static_cast<unsigned>(it.index()); - auto [initVal, yieldOpArg] = it.value(); - - // If i-th yield operand is equal to the i-th operand of the scf.while, - // the i-th before block argument is a loop invariant. - if (yieldOpArg == initVal) { - beforeBlockInitValMap.insert({index, initVal}); - continue; - } else { - // If the i-th yield operand is k-th after block argument, then we check - // if the (k+1)-th condition op operand is equal to either the i-th - // before block argument or the initial value of i-th before block - // argument. If the comparison results `true`, i-th before block - // argument is a loop invariant. 
- auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg); - if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) { - Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()]; - if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) { - beforeBlockInitValMap.insert({index, initVal}); - continue; - } - } - } - newInitArgs.emplace_back(initVal); - newYieldOpArgs.emplace_back(yieldOpArg); - newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc()); - } - - { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(yieldOp); - rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs); - } - - auto newWhile = WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(), - newInitArgs); - - Block &newBeforeBlock = *rewriter.createBlock( - &newWhile.getBefore(), /*insertPt*/ {}, - ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs); - - Block &beforeBlock = *op.getBeforeBody(); - SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments()); - // For each i-th before block argument we find it's replacement value as :- - // 1. If i-th before block argument is a loop invariant, we fetch it's - // initial value from `beforeBlockInitValMap` by querying for key `i`. - // 2. Else we fetch j-th new before block argument as the replacement - // value of i-th before block argument. - for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) { - // If the index 'i' argument was a loop invariant we fetch it's initial - // value from `beforeBlockInitValMap`. - if (beforeBlockInitValMap.count(i) != 0) - newBeforeBlockArgs[i] = beforeBlockInitValMap[i]; - else - newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++); - } - - rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs); - rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(), - newWhile.getAfter().begin()); - - rewriter.replaceOp(op, newWhile.getResults()); - return success(); - } -}; - -/// Remove loop invariant value from result (condition op) of scf.while. -/// A value is considered loop invariant if the final value yielded by -/// scf.condition is defined outside of the `before` block. We remove the -/// corresponding argument in `after` block and replace the use with the value. -/// We also replace the use of the corresponding result of scf.while with the -/// value. -/// -/// Eg: -/// INPUT :- -/// %res_input:K = scf.while <...> iter_args(%arg0_before = , ..., -/// %argN_before = %N) { -/// ... -/// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ... -/// } do { -/// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after): -/// ... -/// some_func(%arg1_after) -/// ... -/// scf.yield %arg0_after, %arg2_after, ..., %argN_after -/// } -/// -/// OUTPUT :- -/// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) { -/// ... -/// scf.condition(%cond) %arg0, %arg1, ..., %argM -/// } do { -/// ^bb0(%arg0, %arg3, ..., %argM): -/// ... -/// some_func(%a) -/// ... -/// scf.yield %arg0, %b, ..., %argN -/// } -/// -/// EXPLANATION: -/// 1. The 1-th and 2-th operand of scf.condition are defined outside the -/// before block of scf.while, so they get removed. -/// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are -/// replaced by %b. -/// 3. The corresponding after block argument %arg1_after's uses are -/// replaced by %a and %arg2_after's uses are replaced by %b. 
-struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> { - using OpRewritePattern<WhileOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(WhileOp op, - PatternRewriter &rewriter) const override { - Block &beforeBlock = *op.getBeforeBody(); - ConditionOp condOp = op.getConditionOp(); - OperandRange condOpArgs = condOp.getArgs(); - - bool canSimplify = false; - for (Value condOpArg : condOpArgs) { - // Those values not defined within `before` block will be considered as - // loop invariant values. We map the corresponding `index` with their - // value. - if (condOpArg.getParentBlock() != &beforeBlock) { - canSimplify = true; - break; - } - } - - if (!canSimplify) - return failure(); - - Block::BlockArgListType afterBlockArgs = op.getAfterArguments(); - - SmallVector<Value> newCondOpArgs; - SmallVector<Type> newAfterBlockType; - DenseMap<unsigned, Value> condOpInitValMap; - SmallVector<Location> newAfterBlockArgLocs; - for (const auto &it : llvm::enumerate(condOpArgs)) { - auto index = static_cast<unsigned>(it.index()); - Value condOpArg = it.value(); - // Those values not defined within `before` block will be considered as - // loop invariant values. We map the corresponding `index` with their - // value. - if (condOpArg.getParentBlock() != &beforeBlock) { - condOpInitValMap.insert({index, condOpArg}); - } else { - newCondOpArgs.emplace_back(condOpArg); - newAfterBlockType.emplace_back(condOpArg.getType()); - newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc()); - } - } - - { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(condOp); - rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(), - newCondOpArgs); - } - - auto newWhile = WhileOp::create(rewriter, op.getLoc(), newAfterBlockType, - op.getOperands()); - - Block &newAfterBlock = - *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {}, - newAfterBlockType, newAfterBlockArgLocs); - - Block &afterBlock = *op.getAfterBody(); - // Since a new scf.condition op was created, we need to fetch the new - // `after` block arguments which will be used while replacing operations of - // previous scf.while's `after` blocks. We'd also be fetching new result - // values too. - SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments()); - SmallVector<Value> newWhileResults(afterBlock.getNumArguments()); - for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) { - Value afterBlockArg, result; - // If index 'i' argument was loop invariant we fetch it's value from the - // `condOpInitMap` map. - if (condOpInitValMap.count(i) != 0) { - afterBlockArg = condOpInitValMap[i]; - result = afterBlockArg; - } else { - afterBlockArg = newAfterBlock.getArgument(j); - result = newWhile.getResult(j); - j++; - } - newAfterBlockArgs[i] = afterBlockArg; - newWhileResults[i] = result; - } - - rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs); - rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(), - newWhile.getBefore().begin()); - - rewriter.replaceOp(op, newWhileResults); - return success(); - } -}; - -/// Remove WhileOp results that are also unused in 'after' block. 
-/// -/// %0:2 = scf.while () : () -> (i32, i64) { -/// %condition = "test.condition"() : () -> i1 -/// %v1 = "test.get_some_value"() : () -> i32 -/// %v2 = "test.get_some_value"() : () -> i64 -/// scf.condition(%condition) %v1, %v2 : i32, i64 -/// } do { -/// ^bb0(%arg0: i32, %arg1: i64): -/// "test.use"(%arg0) : (i32) -> () -/// scf.yield -/// } -/// return %0#0 : i32 -/// -/// becomes -/// %0 = scf.while () : () -> (i32) { -/// %condition = "test.condition"() : () -> i1 -/// %v1 = "test.get_some_value"() : () -> i32 -/// %v2 = "test.get_some_value"() : () -> i64 -/// scf.condition(%condition) %v1 : i32 -/// } do { -/// ^bb0(%arg0: i32): -/// "test.use"(%arg0) : (i32) -> () -/// scf.yield -/// } -/// return %0 : i32 -struct WhileUnusedResult : public OpRewritePattern<WhileOp> { - using OpRewritePattern<WhileOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(WhileOp op, - PatternRewriter &rewriter) const override { - auto term = op.getConditionOp(); - auto afterArgs = op.getAfterArguments(); - auto termArgs = term.getArgs(); - - // Collect results mapping, new terminator args and new result types. - SmallVector<unsigned> newResultsIndices; - SmallVector<Type> newResultTypes; - SmallVector<Value> newTermArgs; - SmallVector<Location> newArgLocs; - bool needUpdate = false; - for (const auto &it : - llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) { - auto i = static_cast<unsigned>(it.index()); - Value result = std::get<0>(it.value()); - Value afterArg = std::get<1>(it.value()); - Value termArg = std::get<2>(it.value()); - if (result.use_empty() && afterArg.use_empty()) { - needUpdate = true; - } else { - newResultsIndices.emplace_back(i); - newTermArgs.emplace_back(termArg); - newResultTypes.emplace_back(result.getType()); - newArgLocs.emplace_back(result.getLoc()); - } - } - - if (!needUpdate) - return failure(); - - { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(term); - rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(), - newTermArgs); - } - - auto newWhile = - WhileOp::create(rewriter, op.getLoc(), newResultTypes, op.getInits()); - - Block &newAfterBlock = *rewriter.createBlock( - &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs); - - // Build new results list and new after block args (unused entries will be - // null). - SmallVector<Value> newResults(op.getNumResults()); - SmallVector<Value> newAfterBlockArgs(op.getNumResults()); - for (const auto &it : llvm::enumerate(newResultsIndices)) { - newResults[it.value()] = newWhile.getResult(it.index()); - newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index()); - } - - rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(), - newWhile.getBefore().begin()); - - Block &afterBlock = *op.getAfterBody(); - rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs); - - rewriter.replaceOp(op, newResults); - return success(); - } -}; - /// Replace operations equivalent to the condition in the do block with true, /// since otherwise the block would not be evaluated. /// @@ -4321,127 +3921,6 @@ struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> { } }; -/// Remove unused init/yield args. 
-struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> { - using OpRewritePattern<WhileOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(WhileOp op, - PatternRewriter &rewriter) const override { - - if (!llvm::any_of(op.getBeforeArguments(), - [](Value arg) { return arg.use_empty(); })) - return rewriter.notifyMatchFailure(op, "No args to remove"); - - YieldOp yield = op.getYieldOp(); - - // Collect results mapping, new terminator args and new result types. - SmallVector<Value> newYields; - SmallVector<Value> newInits; - llvm::BitVector argsToErase; - - size_t argsCount = op.getBeforeArguments().size(); - newYields.reserve(argsCount); - newInits.reserve(argsCount); - argsToErase.reserve(argsCount); - for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip( - op.getBeforeArguments(), yield.getOperands(), op.getInits())) { - if (beforeArg.use_empty()) { - argsToErase.push_back(true); - } else { - argsToErase.push_back(false); - newYields.emplace_back(yieldValue); - newInits.emplace_back(initValue); - } - } - - Block &beforeBlock = *op.getBeforeBody(); - Block &afterBlock = *op.getAfterBody(); - - beforeBlock.eraseArguments(argsToErase); - - Location loc = op.getLoc(); - auto newWhileOp = - WhileOp::create(rewriter, loc, op.getResultTypes(), newInits, - /*beforeBody*/ nullptr, /*afterBody*/ nullptr); - Block &newBeforeBlock = *newWhileOp.getBeforeBody(); - Block &newAfterBlock = *newWhileOp.getAfterBody(); - - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(yield); - rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields); - - rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, - newBeforeBlock.getArguments()); - rewriter.mergeBlocks(&afterBlock, &newAfterBlock, - newAfterBlock.getArguments()); - - rewriter.replaceOp(op, newWhileOp.getResults()); - return success(); - } -}; - -/// Remove duplicated ConditionOp args. 
-struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(WhileOp op, - PatternRewriter &rewriter) const override { - ConditionOp condOp = op.getConditionOp(); - ValueRange condOpArgs = condOp.getArgs(); - - llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs); - - if (argsSet.size() == condOpArgs.size()) - return rewriter.notifyMatchFailure(op, "No results to remove"); - - llvm::SmallDenseMap<Value, unsigned> argsMap; - SmallVector<Value> newArgs; - argsMap.reserve(condOpArgs.size()); - newArgs.reserve(condOpArgs.size()); - for (Value arg : condOpArgs) { - if (!argsMap.count(arg)) { - auto pos = static_cast<unsigned>(argsMap.size()); - argsMap.insert({arg, pos}); - newArgs.emplace_back(arg); - } - } - - ValueRange argsRange(newArgs); - - Location loc = op.getLoc(); - auto newWhileOp = - scf::WhileOp::create(rewriter, loc, argsRange.getTypes(), op.getInits(), - /*beforeBody*/ nullptr, - /*afterBody*/ nullptr); - Block &newBeforeBlock = *newWhileOp.getBeforeBody(); - Block &newAfterBlock = *newWhileOp.getAfterBody(); - - SmallVector<Value> afterArgsMapping; - SmallVector<Value> resultsMapping; - for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) { - auto it = argsMap.find(arg); - assert(it != argsMap.end()); - auto pos = it->second; - afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos)); - resultsMapping.emplace_back(newWhileOp->getResult(pos)); - } - - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(condOp); - rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(), - argsRange); - - Block &beforeBlock = *op.getBeforeBody(); - Block &afterBlock = *op.getAfterBody(); - - rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, - newBeforeBlock.getArguments()); - rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping); - rewriter.replaceOp(op, resultsMapping); - return success(); - } -}; - /// If both ranges contain same values return mappping indices from args2 to /// args1. Otherwise return std::nullopt. static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1, @@ -4532,11 +4011,8 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> { void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<RemoveLoopInvariantArgsFromBeforeBlock, - RemoveLoopInvariantValueYielded, WhileConditionTruth, - WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults, - WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>( - context); + results.add<WhileConditionTruth, WhileCmpCond, WhileOpAlignBeforeArgs, + WhileMoveIfDown>(context); } //===----------------------------------------------------------------------===// @@ -4711,59 +4187,9 @@ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> { } }; -/// Canonicalization patterns that folds away dead results of -/// "scf.index_switch" ops. -struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> { - using OpRewritePattern<IndexSwitchOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(IndexSwitchOp op, - PatternRewriter &rewriter) const override { - // Find dead results. 
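This scf.index_switch-specific pattern is likewise removed in favor of the generic RemoveDeadValues pattern. A minimal sketch of the kind of rewrite involved (hypothetical IR; %idx, %a, %b, %i and %j are assumed to be defined earlier):

// Before: %dead has no uses.
%dead, %live = scf.index_switch %idx -> i64, index
case 1 {
  scf.yield %a, %i : i64, index
}
default {
  scf.yield %b, %j : i64, index
}
"test.use"(%live) : (index) -> ()
// After: only the live result remains; the dead yield operands are dropped.
%live = scf.index_switch %idx -> index
case 1 {
  scf.yield %i : index
}
default {
  scf.yield %j : index
}
"test.use"(%live) : (index) -> ()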
- BitVector deadResults(op.getNumResults(), false); - SmallVector<Type> newResultTypes; - for (auto [idx, result] : llvm::enumerate(op.getResults())) { - if (!result.use_empty()) { - newResultTypes.push_back(result.getType()); - } else { - deadResults[idx] = true; - } - } - if (!deadResults.any()) - return rewriter.notifyMatchFailure(op, "no dead results to fold"); - - // Create new op without dead results and inline case regions. - auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes, - op.getArg(), op.getCases(), - op.getCaseRegions().size()); - auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) { - rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin()); - // Remove respective operands from yield op. - Operation *terminator = newRegion.front().getTerminator(); - assert(isa<YieldOp>(terminator) && "expected yield op"); - rewriter.modifyOpInPlace( - terminator, [&]() { terminator->eraseOperands(deadResults); }); - }; - for (auto [oldRegion, newRegion] : - llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions())) - inlineCaseRegion(oldRegion, newRegion); - inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion()); - - // Replace op with new op. - SmallVector<Value> newResults(op.getNumResults(), Value()); - unsigned nextNewResult = 0; - for (unsigned idx = 0; idx < op.getNumResults(); ++idx) { - if (deadResults[idx]) - continue; - newResults[idx] = newOp.getResult(nextNewResult++); - } - rewriter.replaceOp(op, newResults); - return success(); - } -}; - void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context); + results.add<FoldConstantCase>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp index d393ddb8d8336..ed94205d32f19 100644 --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -521,6 +521,45 @@ void RegionBranchOpInterface::getSuccessorOperandInputMapping( } } +void RegionBranchOpInterface::getSuccessorInputOperandMapping( + RegionBranchInverseSuccessorMapping &mapping) { + RegionBranchSuccessorMapping operandToInputs; + getSuccessorOperandInputMapping(operandToInputs); + for (const auto &[operand, inputs] : operandToInputs) { + for (Value input : inputs) + mapping[input].push_back(operand); + } +} + +DenseSet<Value> +RegionBranchOpInterface::computePossibleValuesOfSuccessorInput(Value value) { + RegionBranchInverseSuccessorMapping inputToOperands; + getSuccessorInputOperandMapping(inputToOperands); + + DenseSet<Value> possibleValues; + DenseSet<Value> visited; + SmallVector<Value> worklist; + + // Starting with the given value, trace back all predecessor values (i.e., + // preceding successor operands) and add them to the set of possible values. + // If the successor operand is again a successor input, do not add it to + // result set, but instead continue the traversal. 
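As an illustration of this traversal (hypothetical IR; %init and %limit are assumed to be defined earlier): in the scf.while below, the only predecessor operand of the after-block argument %a is the scf.condition operand %arg, which is itself a successor input, so the walk continues through it and the possible values of %a come out as {%init, %next} rather than {%arg}.

%res = scf.while (%arg = %init) : (i32) -> i32 {
  %cond = arith.cmpi slt, %arg, %limit : i32
  scf.condition(%cond) %arg : i32
} do {
^bb0(%a: i32):
  %next = arith.addi %a, %a : i32
  scf.yield %next : i32
}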
+ worklist.push_back(value); + while (!worklist.empty()) { + Value next = worklist.pop_back_val(); + auto it = inputToOperands.find(next); + if (it == inputToOperands.end()) { + possibleValues.insert(next); + continue; + } + for (OpOperand *operand : it->second) + if (visited.insert(operand->get()).second) + worklist.push_back(operand->get()); + } + + return possibleValues; +} + SmallVector<RegionBranchPoint> RegionBranchOpInterface::getAllRegionBranchPoints() { SmallVector<RegionBranchPoint> branchPoints; diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index d5d0aee3bbe25..11dc4f04af32e 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -1071,17 +1071,17 @@ func.func @invariant_loop_args_in_same_order(%f_arg0: tensor<i32>) -> (tensor<i3 // CHECK: %[[ZERO:.*]] = arith.constant dense<0> // CHECK: %[[ONE:.*]] = arith.constant dense<1> // CHECK: %[[CST42:.*]] = arith.constant dense<42> -// CHECK: %[[WHILE:.*]]:3 = scf.while (%[[ARG0:.*]] = %[[ZERO]], %[[ARG2:.*]] = %[[ONE]], %[[ARG3:.*]] = %[[ONE]]) +// CHECK: %[[WHILE:.*]]:2 = scf.while (%[[ARG0:.*]] = %[[ZERO]], %[[ARG2:.*]] = %[[ONE]]) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) // CHECK: arith.cmpi slt, %[[ARG0]], %{{.*}} // CHECK: tensor.extract %{{.*}}[] -// CHECK: scf.condition(%{{.*}}) %[[ARG0]], %[[ARG2]], %[[ARG3]] +// CHECK: scf.condition(%{{.*}}) %[[ARG0]], %[[ARG2]] // CHECK: } do { -// CHECK: ^{{.*}}(%[[ARG0:.*]]: tensor<i32>, %[[ARG2:.*]]: tensor<i32>, %[[ARG3:.*]]: tensor<i32>): +// CHECK: ^{{.*}}(%[[ARG0:.*]]: tensor<i32>, %[[ARG2:.*]]: tensor<i32>): // CHECK: %[[VAL0:.*]] = arith.addi %[[ARG0]], %[[FUNC_ARG0]] -// CHECK: %[[VAL1:.*]] = arith.addi %[[ARG2]], %[[ARG3]] -// CHECK: scf.yield %[[VAL0]], %[[VAL1]], %[[VAL1]] +// CHECK: %[[VAL1:.*]] = arith.addi %[[ARG2]], %[[ARG2]] +// CHECK: scf.yield %[[VAL0]], %[[VAL1]] // CHECK: } -// CHECK: return %[[WHILE]]#0, %[[FUNC_ARG0]], %[[WHILE]]#1, %[[WHILE]]#2, %[[ZERO]] +// CHECK: return %[[WHILE]]#0, %[[FUNC_ARG0]], %[[WHILE]]#1, %[[WHILE]]#1, %[[ZERO]] // CHECK-LABEL: @while_loop_invariant_argument_different_order func.func @while_loop_invariant_argument_different_order(%arg : tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) { @@ -1736,11 +1736,11 @@ module { // Test case with multiple scf.yield ops with at least one different operand, then no change. 
-// CHECK: %[[VAL_3:.*]]:2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline { +// CHECK: %[[VAL_3:.*]] = scf.execute_region -> memref<1x120xui8> no_inline { // CHECK: ^bb1: -// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8> +// CHECK: scf.yield %{{.*}} : memref<1x120xui8> // CHECK: ^bb2: -// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8> +// CHECK: scf.yield %{{.*}} : memref<1x120xui8> // CHECK: } module { @@ -2178,16 +2178,14 @@ func.func @scf_for_all_step_size_0() { // CHECK-SAME: %[[arg0:.*]]: index // CHECK-DAG: %[[c10:.*]] = arith.constant 10 // CHECK-DAG: %[[c11:.*]] = arith.constant 11 -// CHECK: %[[switch:.*]] = scf.index_switch %[[arg0]] -> index +// CHECK: scf.index_switch %[[arg0]] // CHECK: case 1 { // CHECK: memref.store %[[c10]] -// CHECK: scf.yield %[[arg0]] : index // CHECK: } // CHECK: default { // CHECK: memref.store %[[c11]] -// CHECK: scf.yield %[[arg0]] : index // CHECK: } -// CHECK: return %[[switch]] +// CHECK: return %[[arg0]] func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> index { %non_live, %live = scf.index_switch %arg0 -> i32, index case 1 { diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir index 5f2aa5e3a2736..b0a1af31d8806 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir @@ -129,13 +129,13 @@ func.func @matmul_sparse_rhs(%a: tensor<10x20xf32>, // CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_28]]] : memref<?xindex> // CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_2]]] : memref<?xindex> // CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_3]]] : memref<?xindex> -// CHECK: %[[VAL_32:.*]]:4 = scf.while (%[[VAL_33:.*]] = %[[VAL_27]], %[[VAL_34:.*]] = %[[VAL_30]], %[[VAL_35:.*]] = %[[VAL_26]], %[[VAL_36:.*]] = %[[VAL_21]]) : (index, index, index, tensor<4x4xf64, #sparse{{[0-9]*}}>) -> (index, index, index, tensor<4x4xf64, #sparse{{[0-9]*}}>) { +// CHECK: %[[VAL_32:.*]]:3 = scf.while (%[[VAL_33:.*]] = %[[VAL_27]], %[[VAL_34:.*]] = %[[VAL_30]], %[[VAL_35:.*]] = %[[VAL_26]]) : (index, index, index) -> (index, index, index) { // CHECK: %[[VAL_37:.*]] = arith.cmpi ult, %[[VAL_33]], %[[VAL_29]] : index // CHECK: %[[VAL_38:.*]] = arith.cmpi ult, %[[VAL_34]], %[[VAL_31]] : index // CHECK: %[[VAL_39:.*]] = arith.andi %[[VAL_37]], %[[VAL_38]] : i1 -// CHECK: scf.condition(%[[VAL_39]]) %[[VAL_33]], %[[VAL_34]], %[[VAL_35]], %[[VAL_36]] : index, index, index, tensor<4x4xf64, #sparse{{[0-9]*}}> +// CHECK: scf.condition(%[[VAL_39]]) %[[VAL_33]], %[[VAL_34]], %[[VAL_35]] : index, index, index // CHECK: } do { -// CHECK: ^bb0(%[[VAL_40:.*]]: index, %[[VAL_41:.*]]: index, %[[VAL_42:.*]]: index, %[[VAL_43:.*]]: tensor<4x4xf64, #sparse{{[0-9]*}}>): +// CHECK: ^bb0(%[[VAL_40:.*]]: index, %[[VAL_41:.*]]: index, %[[VAL_42:.*]]: index): // CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_40]]] : memref<?xindex> // CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_41]]] : memref<?xindex> // CHECK: %[[VAL_46:.*]] = arith.cmpi ult, %[[VAL_45]], %[[VAL_44]] : index @@ -143,7 +143,7 @@ func.func @matmul_sparse_rhs(%a: tensor<10x20xf32>, // CHECK: %[[VAL_48:.*]] = arith.cmpi eq, %[[VAL_44]], %[[VAL_47]] : index // CHECK: %[[VAL_49:.*]] = arith.cmpi eq, %[[VAL_45]], %[[VAL_47]] : index // CHECK: %[[VAL_50:.*]] = arith.andi %[[VAL_48]], %[[VAL_49]] : i1 -// CHECK: %[[VAL_51:.*]]:2 = scf.if %[[VAL_50]] -> 
(index, tensor<4x4xf64, #sparse{{[0-9]*}}>) {
+// CHECK: %[[VAL_51:.*]] = scf.if %[[VAL_50]] -> (index) {
 // CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_40]]] : memref<?xf64>
 // CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_41]]] : memref<?xindex>
 // CHECK: %[[VAL_54:.*]] = arith.addi %[[VAL_41]], %[[VAL_3]] : index
@@ -167,9 +167,9 @@ func.func @matmul_sparse_rhs(%a: tensor<10x20xf32>,
 // CHECK: memref.store %[[VAL_63]], %[[VAL_23]]{{\[}}%[[VAL_59]]] : memref<?xf64>
 // CHECK: scf.yield %[[VAL_68:.*]] : index
 // CHECK: }
-// CHECK: scf.yield %[[VAL_69:.*]], %[[VAL_43]] : index, tensor<4x4xf64, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_69:.*]] : index
 // CHECK: } else {
-// CHECK: scf.yield %[[VAL_42]], %[[VAL_43]] : index, tensor<4x4xf64, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_42]] : index
 // CHECK: }
 // CHECK: %[[VAL_70:.*]] = arith.cmpi eq, %[[VAL_44]], %[[VAL_47]] : index
 // CHECK: %[[VAL_71:.*]] = arith.addi %[[VAL_40]], %[[VAL_3]] : index
@@ -177,9 +177,9 @@ func.func @matmul_sparse_rhs(%a: tensor<10x20xf32>,
 // CHECK: %[[VAL_73:.*]] = arith.cmpi eq, %[[VAL_45]], %[[VAL_47]] : index
 // CHECK: %[[VAL_74:.*]] = arith.addi %[[VAL_41]], %[[VAL_3]] : index
 // CHECK: %[[VAL_75:.*]] = arith.select %[[VAL_73]], %[[VAL_74]], %[[VAL_41]] : index
-// CHECK: scf.yield %[[VAL_72]], %[[VAL_75]], %[[VAL_76:.*]]#0, %[[VAL_76]]#1 : index, index, index, tensor<4x4xf64, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_72]], %[[VAL_75]], %[[VAL_76:.*]] : index, index, index
 // CHECK: }
-// CHECK: %[[VAL_77:.*]] = sparse_tensor.compress %[[VAL_23]], %[[VAL_24]], %[[VAL_25]], %[[VAL_78:.*]]#2 into %[[VAL_78]]#3{{\[}}%[[VAL_22]]] : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<4x4xf64, #sparse{{[0-9]*}}>
+// CHECK: %[[VAL_77:.*]] = sparse_tensor.compress %[[VAL_23]], %[[VAL_24]], %[[VAL_25]], %[[VAL_78:.*]]#2 into %[[VAL_21]]{{\[}}%[[VAL_22]]] : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<4x4xf64, #sparse{{[0-9]*}}>
 // CHECK: scf.yield %[[VAL_77]] : tensor<4x4xf64, #sparse{{[0-9]*}}>
 // CHECK: }
 // CHECK: %[[VAL_79:.*]] = sparse_tensor.load %[[VAL_80:.*]] hasInserts : tensor<4x4xf64, #sparse{{[0-9]*}}>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
index 4dff06b8155dd..67d1573058460 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
@@ -216,13 +216,13 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x
 // CHECK: %[[VAL_71:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_58]]] : memref<?xindex>
 // CHECK: %[[VAL_72:.*]] = arith.addi %[[VAL_58]], %[[VAL_3]] : index
 // CHECK: %[[VAL_73:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_72]]] : memref<?xindex>
-// CHECK: %[[VAL_74:.*]]:5 = scf.while (%[[VAL_75:.*]] = %[[VAL_68]], %[[VAL_76:.*]] = %[[VAL_71]], %[[VAL_77:.*]] = %[[VAL_4]], %[[VAL_200:.*]] = %[[VAL_FALSE]], %[[VAL_78:.*]] = %[[VAL_59]]) : (index, index, i32, i1, tensor<?x?xi32, #sparse{{[0-9]*}}>) -> (index, index, i32, i1, tensor<?x?xi32, #sparse{{[0-9]*}}>) {
+// CHECK: %[[VAL_74:.*]]:4 = scf.while (%[[VAL_75:.*]] = %[[VAL_68]], %[[VAL_76:.*]] = %[[VAL_71]], %[[VAL_77:.*]] = %[[VAL_4]], %[[VAL_200:.*]] = %[[VAL_FALSE]]) : (index, index, i32, i1) -> (index, index, i32, i1) {
 // CHECK: %[[VAL_79:.*]] = arith.cmpi ult, %[[VAL_75]], %[[VAL_70]] : index
 // CHECK: %[[VAL_80:.*]] = arith.cmpi ult, %[[VAL_76]], %[[VAL_73]] : index
 // CHECK: %[[VAL_81:.*]] = arith.andi %[[VAL_79]], %[[VAL_80]] : i1
-// CHECK: scf.condition(%[[VAL_81]]) %[[VAL_75]], %[[VAL_76]], %[[VAL_77]], %[[VAL_200]], %[[VAL_78]] : index, index, i32, i1, tensor<?x?xi32, #sparse{{[0-9]*}}>
+// CHECK: scf.condition(%[[VAL_81]]) %[[VAL_75]], %[[VAL_76]], %[[VAL_77]], %[[VAL_200]] : index, index, i32, i1
 // CHECK: } do {
-// CHECK: ^bb0(%[[VAL_82:.*]]: index, %[[VAL_83:.*]]: index, %[[VAL_84:.*]]: i32, %[[VAL_201:.*]]: i1, %[[VAL_85:.*]]: tensor<?x?xi32, #sparse{{[0-9]*}}>):
+// CHECK: ^bb0(%[[VAL_82:.*]]: index, %[[VAL_83:.*]]: index, %[[VAL_84:.*]]: i32, %[[VAL_201:.*]]: i1):
 // CHECK: %[[VAL_86:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_82]]] : memref<?xindex>
 // CHECK: %[[VAL_87:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_83]]] : memref<?xindex>
 // CHECK: %[[VAL_88:.*]] = arith.cmpi ult, %[[VAL_87]], %[[VAL_86]] : index
@@ -230,14 +230,14 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x
 // CHECK: %[[VAL_90:.*]] = arith.cmpi eq, %[[VAL_86]], %[[VAL_89]] : index
 // CHECK: %[[VAL_91:.*]] = arith.cmpi eq, %[[VAL_87]], %[[VAL_89]] : index
 // CHECK: %[[VAL_92:.*]] = arith.andi %[[VAL_90]], %[[VAL_91]] : i1
-// CHECK: %[[VAL_93:.*]]:3 = scf.if %[[VAL_92]] -> (i32, i1, tensor<?x?xi32, #sparse{{[0-9]*}}>) {
+// CHECK: %[[VAL_93:.*]]:2 = scf.if %[[VAL_92]] -> (i32, i1) {
 // CHECK: %[[VAL_94:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_82]]] : memref<?xi32>
 // CHECK: %[[VAL_95:.*]] = memref.load %[[VAL_21]]{{\[}}%[[VAL_83]]] : memref<?xi32>
 // CHECK: %[[VAL_96:.*]] = arith.muli %[[VAL_94]], %[[VAL_95]] : i32
 // CHECK: %[[VAL_97:.*]] = arith.addi %[[VAL_84]], %[[VAL_96]] : i32
-// CHECK: scf.yield %[[VAL_97]], %[[VAL_TRUE]], %[[VAL_85]] : i32, i1, tensor<?x?xi32, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_97]], %[[VAL_TRUE]] : i32, i1
 // CHECK: } else {
-// CHECK: scf.yield %[[VAL_84]], %[[VAL_201]], %[[VAL_85]] : i32, i1, tensor<?x?xi32, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_84]], %[[VAL_201]] : i32, i1
 // CHECK: }
 // CHECK: %[[VAL_98:.*]] = arith.cmpi eq, %[[VAL_86]], %[[VAL_89]] : index
 // CHECK: %[[VAL_99:.*]] = arith.addi %[[VAL_82]], %[[VAL_3]] : index
@@ -245,13 +245,13 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x
 // CHECK: %[[VAL_101:.*]] = arith.cmpi eq, %[[VAL_87]], %[[VAL_89]] : index
 // CHECK: %[[VAL_102:.*]] = arith.addi %[[VAL_83]], %[[VAL_3]] : index
 // CHECK: %[[VAL_103:.*]] = arith.select %[[VAL_101]], %[[VAL_102]], %[[VAL_83]] : index
-// CHECK: scf.yield %[[VAL_100]], %[[VAL_103]], %[[VAL_104:.*]]#0, %[[VAL_104]]#1, %[[VAL_104]]#2 : index, index, i32, i1, tensor<?x?xi32, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_100]], %[[VAL_103]], %[[VAL_104:.*]]#0, %[[VAL_104]]#1 : index, index, i32, i1
 // CHECK: }
 // CHECK: %[[VAL_202:.*]] = scf.if %[[VAL_74]]#3 -> (tensor<?x?xi32, #sparse{{[0-9]*}}>) {
-// CHECK: %[[VAL_105:.*]] = tensor.insert %[[VAL_74]]#2 into %[[VAL_74]]#4{{\[}}%[[VAL_39]], %[[VAL_63]]] : tensor<?x?xi32, #sparse{{[0-9]*}}>
+// CHECK: %[[VAL_105:.*]] = tensor.insert %[[VAL_74]]#2 into %[[VAL_59]]{{\[}}%[[VAL_39]], %[[VAL_63]]] : tensor<?x?xi32, #sparse{{[0-9]*}}>
 // CHECK: scf.yield %[[VAL_105]] : tensor<?x?xi32, #sparse{{[0-9]*}}>
 // CHECK: } else {
-// CHECK: scf.yield %[[VAL_74]]#4 : tensor<?x?xi32, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_59]] : tensor<?x?xi32, #sparse{{[0-9]*}}>
 // CHECK: }
 // CHECK: scf.yield %[[VAL_202]] : tensor<?x?xi32, #sparse{{[0-9]*}}>
 // CHECK: } else {
@@ -339,13 +339,13 @@ func.func @sumred(%arga: tensor<?x?x?xi32, #SparseTensor>,
 // CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<?xindex>
 // CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK: %[[VAL_34:.*]]:4 = scf.while (%[[VAL_35:.*]] = %[[VAL_29]], %[[VAL_36:.*]] = %[[VAL_32]], %[[VAL_37:.*]] = %[[VAL_28]], %[[VAL_38:.*]] = %[[VAL_23]]) : (index, index, index, tensor<?x?xf32, #sparse{{[0-9]*}}>) -> (index, index, index, tensor<?x?xf32, #sparse{{[0-9]*}}>) {
+// CHECK: %[[VAL_34:.*]]:3 = scf.while (%[[VAL_35:.*]] = %[[VAL_29]], %[[VAL_36:.*]] = %[[VAL_32]], %[[VAL_37:.*]] = %[[VAL_28]]) : (index, index, index) -> (index, index, index) {
 // CHECK: %[[VAL_39:.*]] = arith.cmpi ult, %[[VAL_35]], %[[VAL_31]] : index
 // CHECK: %[[VAL_40:.*]] = arith.cmpi ult, %[[VAL_36]], %[[VAL_33]] : index
 // CHECK: %[[VAL_41:.*]] = arith.andi %[[VAL_39]], %[[VAL_40]] : i1
-// CHECK: scf.condition(%[[VAL_41]]) %[[VAL_35]], %[[VAL_36]], %[[VAL_37]], %[[VAL_38]] : index, index, index, tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK: scf.condition(%[[VAL_41]]) %[[VAL_35]], %[[VAL_36]], %[[VAL_37]] : index, index, index
 // CHECK: } do {
-// CHECK: ^bb0(%[[VAL_42:.*]]: index, %[[VAL_43:.*]]: index, %[[VAL_44:.*]]: index, %[[VAL_45:.*]]: tensor<?x?xf32, #sparse{{[0-9]*}}>):
+// CHECK: ^bb0(%[[VAL_42:.*]]: index, %[[VAL_43:.*]]: index, %[[VAL_44:.*]]: index):
 // CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_42]]] : memref<?xindex>
 // CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_43]]] : memref<?xindex>
 // CHECK: %[[VAL_48:.*]] = arith.cmpi ult, %[[VAL_47]], %[[VAL_46]] : index
@@ -353,7 +353,7 @@ func.func @sumred(%arga: tensor<?x?x?xi32, #SparseTensor>,
 // CHECK: %[[VAL_50:.*]] = arith.cmpi eq, %[[VAL_46]], %[[VAL_49]] : index
 // CHECK: %[[VAL_51:.*]] = arith.cmpi eq, %[[VAL_47]], %[[VAL_49]] : index
 // CHECK: %[[VAL_52:.*]] = arith.andi %[[VAL_50]], %[[VAL_51]] : i1
-// CHECK: %[[VAL_53:.*]]:2 = scf.if %[[VAL_52]] -> (index, tensor<?x?xf32, #sparse{{[0-9]*}}>) {
+// CHECK: %[[VAL_53:.*]] = scf.if %[[VAL_52]] -> (index) {
 // CHECK: %[[VAL_54:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_42]]] : memref<?xf32>
 // CHECK: %[[VAL_55:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_43]]] : memref<?xindex>
 // CHECK: %[[VAL_56:.*]] = arith.addi %[[VAL_43]], %[[VAL_3]] : index
@@ -377,9 +377,9 @@ func.func @sumred(%arga: tensor<?x?x?xi32, #SparseTensor>,
 // CHECK: memref.store %[[VAL_65]], %[[VAL_25]]{{\[}}%[[VAL_61]]] : memref<?xf32>
 // CHECK: scf.yield %[[VAL_70:.*]] : index
 // CHECK: }
-// CHECK: scf.yield %[[VAL_71:.*]], %[[VAL_45]] : index, tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_71:.*]] : index
 // CHECK: } else {
-// CHECK: scf.yield %[[VAL_44]], %[[VAL_45]] : index, tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_44]] : index
 // CHECK: }
 // CHECK: %[[VAL_72:.*]] = arith.cmpi eq, %[[VAL_46]], %[[VAL_49]] : index
 // CHECK: %[[VAL_73:.*]] = arith.addi %[[VAL_42]], %[[VAL_3]] : index
@@ -387,9 +387,9 @@ func.func @sumred(%arga: tensor<?x?x?xi32, #SparseTensor>,
 // CHECK: %[[VAL_75:.*]] = arith.cmpi eq, %[[VAL_47]], %[[VAL_49]] : index
 // CHECK: %[[VAL_76:.*]] = arith.addi %[[VAL_43]], %[[VAL_3]] : index
 // CHECK: %[[VAL_77:.*]] = arith.select %[[VAL_75]], %[[VAL_76]], %[[VAL_43]] : index
-// CHECK: scf.yield %[[VAL_74]], %[[VAL_77]], %[[VAL_78:.*]]#0, %[[VAL_78]]#1 : index, index, index, tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_74]], %[[VAL_77]], %[[VAL_78:.*]] : index, index, index
 // CHECK: }
-// CHECK: %[[VAL_79:.*]] = sparse_tensor.compress %[[VAL_25]], %[[VAL_26]], %[[VAL_27]], %[[VAL_80:.*]]#2 into %[[VAL_80]]#3{{\[}}%[[VAL_24]]] : memref<?xf32>, memref<?xi1>, memref<?xindex>, tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK: %[[VAL_79:.*]] = sparse_tensor.compress %[[VAL_25]], %[[VAL_26]], %[[VAL_27]], %[[VAL_80:.*]]#2 into %[[VAL_23]]{{\[}}%[[VAL_24]]] : memref<?xf32>, memref<?xi1>, memref<?xindex>, tensor<?x?xf32, #sparse{{[0-9]*}}>
 // CHECK: scf.yield %[[VAL_79]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
 // CHECK: }
 // CHECK: %[[VAL_81:.*]] = sparse_tensor.load %[[VAL_82:.*]] hasInserts : tensor<?x?xf32, #sparse{{[0-9]*}}>
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 135db02d543ef..18fb6852f6875 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1330,11 +1330,11 @@ func.func @vector_insert_1d_broadcast(%laneid: index, %pos: index) -> (vector<96
 // -----
 
 // CHECK-PROP-LABEL: func @vector_insert_0d(
-// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0{{.*}} -> (vector<f32>, f32)
+// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (f32)
 // CHECK-PROP: %[[VEC:.*]] = "some_def"
 // CHECK-PROP: %[[VAL:.*]] = "another_def"
-// CHECK-PROP: gpu.yield %[[VEC]], %[[VAL]]
-// CHECK-PROP: vector.broadcast %[[W]]#1 : f32 to vector<f32>
+// CHECK-PROP: gpu.yield %[[VAL]]
+// CHECK-PROP: vector.broadcast %[[W]] : f32 to vector<f32>
 func.func @vector_insert_0d(%laneid: index) -> (vector<f32>) {
   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<f32>) {
     %0 = "some_def"() : () -> (vector<f32>)
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index b9a883dbd524e..5bf5487974d35 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -250,13 +250,13 @@ func.func @main() -> (i32, i32) {
 // CHECK-NEXT: }
 
 // CHECK-CANONICALIZE: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 {
-// CHECK-CANONICALIZE-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) {
+// CHECK-CANONICALIZE: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) {
 // CHECK-CANONICALIZE-NEXT: %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]]
-// CHECK-CANONICALIZE-NEXT: scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32
+// CHECK-CANONICALIZE: scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32
 // CHECK-CANONICALIZE-NEXT: } do {
 // CHECK-CANONICALIZE-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
 // CHECK-CANONICALIZE-NEXT: %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]]
-// CHECK-CANONICALIZE-NEXT: scf.yield %[[live_1]] : i32
+// CHECK-CANONICALIZE: scf.yield %[[live_1]] : i32
 // CHECK-CANONICALIZE-NEXT: }
 // CHECK-CANONICALIZE-NEXT: return %[[live_and_non_live]]#0
 // CHECK-CANONICALIZE-NEXT: }
@@ -306,7 +306,7 @@ func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_o
 // CHECK-CANONICALIZE: func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 {
 // CHECK-CANONICALIZE-NEXT: %[[c0:.*]] = arith.constant 0
 // CHECK-CANONICALIZE-NEXT: %[[c1:.*]] = arith.constant 1
-// CHECK-CANONICALIZE-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) {
+// CHECK-CANONICALIZE: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) {
 // CHECK-CANONICALIZE-NEXT: func.call @identity() : () -> ()
 // CHECK-CANONICALIZE-NEXT: scf.condition(%[[arg2]]) %[[arg3]], %[[arg4]] : i32, i32
 // CHECK-CANONICALIZE-NEXT: } do {
_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
