================
@@ -583,3 +624,283 @@ Region *mlir::getEnclosingRepetitiveRegion(Value value) {
LDBG() << "No enclosing repetitive region found for value";
return nullptr;
}
+
+/// Returns true if `a` is structurally known to be available at the
+/// definition point of `b`, i.e., uses of `b` could be replaced with `a`
+/// without breaking dominance.
+///
+/// This is a conservative check that does not use DominanceInfo:
+/// * `a` must be defined in the region of `b` or in an ancestor region of it.
+/// * Inside a's region, `b` (or the op of that region that transitively
+///   contains b's definition) must be in the same block as `a` and come after
+///   it. The only cross-block exception is an argument of the entry block.
+/// A `false` result means "not known", not "defined after".
+static bool isDefinedBefore(Value a, Value b) {
+  Region *aRegion = a.getParentRegion();
+  Region *bRegion = b.getParentRegion();
+  // Values in detached blocks/ops have no region to reason about.
+  if (!aRegion || !bRegion)
+    return false;
+
+  // Locate b's definition from the point of view of a's region: the block of
+  // `aRegion` that (transitively) contains `b`, and the op of that block that
+  // defines or contains `b` (null if `b` is an argument of that block).
+  Block *bBlock = nullptr;
+  Operation *bOp = nullptr;
+  if (aRegion == bRegion) {
+    bBlock = b.getParentBlock();
+    bOp = b.getDefiningOp();
+  } else {
+    // `b` must be nested under an op of `aRegion`. Comparing only the parent
+    // ops is not enough: `b` may live in a sibling region of `a`, under an op
+    // that precedes `a`, or under the op that defines `a`.
+    Operation *bParentOp = bRegion->getParentOp();
+    if (!bParentOp)
+      return false;
+    bOp = aRegion->findAncestorOpInRegion(*bParentOp);
+    if (!bOp)
+      return false;
+    bBlock = bOp->getBlock();
+  }
+
+  Block *aBlock = a.getParentBlock();
+  if (aBlock != bBlock) {
+    // Arguments of the entry block are available throughout their region.
+    return isa<BlockArgument>(a) && aBlock->isEntryBlock();
+  }
+  // Block arguments are available throughout their block.
+  if (isa<BlockArgument>(a))
+    return true;
+  // `a` is an op result while `b` is an argument of the same block.
+  if (!bOp)
+    return false;
+  // Note: if `a` is a result of `bOp` itself (e.g. `b` is nested in a's
+  // defining op), this is false, as required.
+  return a.getDefiningOp()->isBeforeInBlock(bOp);
+}
+
+namespace {
+// Try to make successor inputs dead by replacing their uses with values that
+// are not successor inputs. This pattern enables additional canonicalization
+// opportunities for RemoveDeadValues.
+//
+// A successor input (op result or region argument) is replaced only if:
+// * it has exactly one possible value, different from itself,
+// * that value has exactly the same type (region branch ops may only require
+//   "compatible" types along their edges, but RAUW needs identical types),
+// * that value is structurally known to be available where the successor
+//   input is defined (see isDefinedBefore).
+struct RemoveUsesOfIdenticalValues : public RewritePattern {
+  RemoveUsesOfIdenticalValues(MLIRContext *context, StringRef name,
+                              PatternBenefit benefit = 1)
+      : RewritePattern(name, benefit, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    auto regionBranchOp = cast<RegionBranchOpInterface>(op);
+
+    // Gather all potential successor inputs. (Other values may also be
+    // included, but we're not doing anything with them.)
+    SmallVector<Value> values;
+    llvm::append_range(values, regionBranchOp->getResults());
+    for (Region &r : op->getRegions())
+      llvm::append_range(values, r.getArguments());
+
+    bool changed = false;
+    for (Value value : values) {
+      if (value.use_empty())
+        continue;
+      DenseSet<Value> possibleValues =
+          regionBranchOp.computePossibleValuesOfSuccessorInput(value);
+      if (possibleValues.size() != 1)
+        continue;
+      Value replacement = *possibleValues.begin();
+      if (replacement == value)
+        continue;
+      // Replacing with a merely type-compatible value would produce invalid
+      // IR, so require an exact type match.
+      if (replacement.getType() != value.getType())
+        continue;
+      if (!isDefinedBefore(replacement, value))
+        continue;
+      // Value is same as another value.
+      rewriter.replaceAllUsesWith(value, replacement);
+      changed = true;
+    }
+
+    if (!changed)
+      return rewriter.notifyMatchFailure(op, "no successor input to replace");
+    return success();
+  }
+};
+
+/// Returns the bit vector stored under `key` in `mapping`. When `mapping` has
+/// no entry for `key` yet, a new entry of `size` bits, all cleared, is inserted
+/// first. An existing entry is returned unchanged; its size is not compared
+/// against `size`.
+template <typename MappingTy, typename KeyTy>
+static BitVector &lookupOrCreateBitVector(MappingTy &mapping, KeyTy key,
+                                          unsigned size) {
+  // try_emplace constructs the value only when `key` is not present.
+  auto insertion = mapping.try_emplace(key, size, false);
+  BitVector &bits = insertion.first->second;
+  return bits;
+}
+
+/// Pattern to remove dead values from region branch ops.
+struct RemoveDeadValues : public RewritePattern {
+  /// Creates the pattern rooted at operations named `rootName`. Such ops are
+  /// expected to implement RegionBranchOpInterface (matchAndRewrite casts).
+  RemoveDeadValues(MLIRContext *ctx, StringRef rootName,
+                   PatternBenefit patternBenefit = 1)
+      : RewritePattern(rootName, patternBenefit, ctx) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ auto regionBranchOp = cast<RegionBranchOpInterface>(op);
+
+ // Compute tied values: values that must come as a set. If you remove one,
+ // you must remove all.
+ RegionBranchSuccessorMapping operandToInputs;
+ regionBranchOp.getSuccessorOperandInputMapping(operandToInputs);
+ llvm::EquivalenceClasses<Value> tiedSuccessorInputs;
+ for (const auto &[operand, inputs] : operandToInputs) {
+ assert(!inputs.empty() && "expected non-empty inputs");
+ Value firstInput = inputs.front();
+ tiedSuccessorInputs.insert(firstInput);
+ for (Value nextInput : llvm::drop_begin(inputs))
+ tiedSuccessorInputs.unionSets(firstInput, nextInput);
+ }
+
+ // Determine which values to remove and group them by block and operation.
+ SmallVector<Value> valuesToRemove;
+ DenseMap<Block *, BitVector> blockArgsToRemove;
+ DenseMap<Operation *, BitVector> resultsToRemove;
+ for (auto it = tiedSuccessorInputs.begin(), e = tiedSuccessorInputs.end();
+ it != e; ++it) {
+ if (!(*it)->isLeader())
+ continue;
+
+ // Value can be removed if it is dead and all other tied values are also
+ // dead.
+ bool allDead = true;
+ for (auto memberIt = tiedSuccessorInputs.member_begin(**it);
+ memberIt != tiedSuccessorInputs.member_end(); ++memberIt) {
+ if (!memberIt->use_empty()) {
+ allDead = false;
+ break;
+ }
+ }
+ if (!allDead)
+ continue;
+
+ // Group values by block and operation.
+ for (auto memberIt = tiedSuccessorInputs.member_begin(**it);
+ memberIt != tiedSuccessorInputs.member_end(); ++memberIt) {
+ if (auto arg = dyn_cast<BlockArgument>(*memberIt)) {
+ BitVector &vector =
+ lookupOrCreateBitVector(blockArgsToRemove, arg.getOwner(),
+ arg.getOwner()->getNumArguments());
+ vector.set(arg.getArgNumber());
+ } else {
+ OpResult result = cast<OpResult>(*memberIt);
+ BitVector &vector =
+ lookupOrCreateBitVector(resultsToRemove, result.getDefiningOp(),
+ result.getDefiningOp()->getNumResults());
+ vector.set(result.getResultNumber());
+ }
+ valuesToRemove.push_back(*memberIt);
+ }
+ }
+
+ if (valuesToRemove.empty())
+ return rewriter.notifyMatchFailure(op, "no values to remove");
+
+ // Find operands that must be removed together with the values.
+ RegionBranchInverseSuccessorMapping inputsToOperands;
+ regionBranchOp.getSuccessorInputOperandMapping(inputsToOperands);
+ DenseMap<Operation *, llvm::BitVector> operandsToRemove;
+ for (Value value : valuesToRemove) {
+ for (OpOperand *operand : inputsToOperands[value]) {
+ BitVector &vector =
+ lookupOrCreateBitVector(operandsToRemove, operand->getOwner(),
+ operand->getOwner()->getNumOperands());
+ vector.set(operand->getOperandNumber());
+ }
+ }
+
+ // Erase operands.
+ for (auto &pair : operandsToRemove) {
+ Operation *op = pair.first;
+ BitVector &operands = pair.second;
+ rewriter.modifyOpInPlace(op, [&]() { op->eraseOperands(operands); });
+ }
+
+ // Erase block arguments.
+ for (auto &pair : blockArgsToRemove) {
+ Block *block = pair.first;
+ BitVector &blockArg = pair.second;
+ rewriter.modifyOpInPlace(block->getParentOp(),
+ [&]() { block->eraseArguments(blockArg); });
+ }
+
+ // Erase op results.
+ // TODO: Can we move this to RewriterBase, so we have a uniform API,
+ // similar to eraseArguments?
+ for (auto [op, resultsToErase] : resultsToRemove) {
+ rewriter.setInsertionPoint(op);
+ SmallVector<Type> newResultTypes;
+ for (OpResult result : op->getResults())
+ if (!resultsToErase[result.getResultNumber()])
+ newResultTypes.push_back(result.getType());
+ OperationState state(op->getLoc(), op->getName().getStringRef(),
+ op->getOperands(), newResultTypes, op->getAttrs());
+ for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
----------------
Hardcode84 wrote:
nit: use `llvm::seq` for this index loop, e.g. `for (unsigned i : llvm::seq<unsigned>(0, op->getNumRegions()))`.
https://github.com/llvm/llvm-project/pull/174094
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits