================
@@ -583,3 +624,283 @@ Region *mlir::getEnclosingRepetitiveRegion(Value value) {
   LDBG() << "No enclosing repetitive region found for value";
   return nullptr;
 }
+
+/// Return true if `a` is defined "before" `b`, i.e. if `a` is visible at every
+/// use of `b`, so that all uses of `b` may be replaced with `a`.
+///
+/// This is a conservative, structural approximation of dominance that does not
+/// require DominanceInfo. Starting from `b` (its defining op, or the op that
+/// owns the block of a block argument), it walks up the parent-op chain and
+/// returns true only if it reaches one of two places:
+///   * the block of `a`, at an op that comes after the definition of `a` (block
+///     arguments are available throughout their block);
+///   * another block of the region of `a`, when `a` lives in that region's
+///     entry block (the entry block dominates all other blocks of its region).
+/// The walk gives up (returns false) when it would leave an op that is
+/// isolated from above, because `a` is not visible inside such an op. It also
+/// gives up on detached IR. Two arguments of the same block are always
+/// considered interchangeable.
+static bool isDefinedBefore(Value a, Value b) {
+  Block *aBlock = a.getParentBlock();
+  if (!aBlock)
+    return false;
+
+  // `bPos` is an op that encloses (or, for op results, precedes) every use of
+  // `b`. It is moved outwards until it lands next to `a`.
+  Operation *bPos = nullptr;
+  if (auto bArg = dyn_cast<BlockArgument>(b)) {
+    Block *bBlock = bArg.getOwner();
+    // All arguments of a block are available from the start of that block.
+    if (bBlock == aBlock)
+      return isa<BlockArgument>(a);
+    bPos = bBlock->getParentOp();
+    // Uses of `b` are nested inside `bPos`. A value defined outside of it is
+    // not visible there if `bPos` is isolated from above.
+    if (!bPos || bPos->hasTrait<OpTrait::IsIsolatedFromAbove>())
+      return false;
+  } else {
+    bPos = b.getDefiningOp();
+  }
+
+  Region *aRegion = aBlock->getParent();
+  // Checked up front (and null-safe) so the loop can short-circuit on it.
+  bool aInEntryBlock = aRegion && aBlock->isEntryBlock();
+  while (bPos->getBlock() != aBlock) {
+    // Another block of `a`'s region: it is dominated by the entry block.
+    if (aInEntryBlock && bPos->getParentRegion() == aRegion)
+      return true;
+    Operation *parent = bPos->getParentOp();
+    // Ran out of IR, or would have to see through an isolated-from-above op.
+    if (!parent || parent->hasTrait<OpTrait::IsIsolatedFromAbove>())
+      return false;
+    bPos = parent;
+  }
+
+  // `bPos` is now in `aBlock`. Block arguments dominate the whole block. An op
+  // result is only usable after its defining op; `isBeforeInBlock` is strict,
+  // so `b` nested in (or produced by) `a`'s own defining op is rejected.
+  if (isa<BlockArgument>(a))
+    return true;
+  return a.getDefiningOp()->isBeforeInBlock(bPos);
+}
+
+namespace {
+// Pattern that forwards values through region branch ops. For each op result
+// and region argument, it asks the interface which values can flow into it. If
+// that is exactly one other value, and `isDefinedBefore` accepts it, all uses
+// are redirected to that value. The original value may then become dead and
+// be picked up by the RemoveDeadValues pattern.
+struct RemoveUsesOfIdenticalValues : public RewritePattern {
+  RemoveUsesOfIdenticalValues(MLIRContext *context, StringRef name,
+                              PatternBenefit benefit = 1)
+      : RewritePattern(name, benefit, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    auto branchOp = cast<RegionBranchOpInterface>(op);
+
+    // Candidates are the op results followed by every region's arguments.
+    // Some of them may not be successor inputs at all; those simply never
+    // qualify for rewriting below.
+    SmallVector<Value> candidates;
+    llvm::append_range(candidates, branchOp->getResults());
+    for (Region &region : op->getRegions())
+      llvm::append_range(candidates, region.getArguments());
+
+    bool replacedAny = false;
+    for (Value candidate : candidates) {
+      // Already dead: nothing to forward.
+      if (candidate.use_empty())
+        continue;
+
+      DenseSet<Value> sources =
+          branchOp.computePossibleValuesOfSuccessorInput(candidate);
+      // Only a single, distinct source value can stand in for the candidate.
+      if (sources.size() != 1)
+        continue;
+      Value source = *sources.begin();
+      if (source == candidate || !isDefinedBefore(source, candidate))
+        continue;
+
+      rewriter.replaceAllUsesWith(candidate, source);
+      replacedAny = true;
+    }
+    return success(replacedAny);
+  }
+};
+
+/// Return the bit vector stored under `key` in `mapping` (e.g. a DenseMap).
+/// When `key` is absent, a new all-false bit vector of `size` bits is inserted
+/// first. `size` is ignored for keys that already have an entry.
+template <typename MappingTy, typename KeyTy>
+static BitVector &lookupOrCreateBitVector(MappingTy &mapping, KeyTy key,
+                                          unsigned size) {
+  // try_emplace constructs BitVector(size, false) only if `key` is missing.
+  auto insertion = mapping.try_emplace(key, size, false);
+  BitVector &bits = insertion.first->second;
+  return bits;
+}
+
+/// Pattern to remove dead values from region branch ops.
+struct RemoveDeadValues : public RewritePattern {
+  RemoveDeadValues(MLIRContext *context, StringRef name,
+                   PatternBenefit benefit = 1)
+      : RewritePattern(name, benefit, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    auto regionBranchOp = cast<RegionBranchOpInterface>(op);
+
+    // Compute tied values: values that must come as a set. If you remove one,
+    // you must remove all.
+    RegionBranchSuccessorMapping operandToInputs;
+    regionBranchOp.getSuccessorOperandInputMapping(operandToInputs);
+    llvm::EquivalenceClasses<Value> tiedSuccessorInputs;
+    for (const auto &[operand, inputs] : operandToInputs) {
+      assert(!inputs.empty() && "expected non-empty inputs");
+      Value firstInput = inputs.front();
+      tiedSuccessorInputs.insert(firstInput);
+      for (Value nextInput : llvm::drop_begin(inputs))
+        tiedSuccessorInputs.unionSets(firstInput, nextInput);
+    }
+
+    // Determine which values to remove and group them by block and operation.
+    SmallVector<Value> valuesToRemove;
+    DenseMap<Block *, BitVector> blockArgsToRemove;
+    DenseMap<Operation *, BitVector> resultsToRemove;
+    for (auto it = tiedSuccessorInputs.begin(), e = tiedSuccessorInputs.end();
+         it != e; ++it) {
+      if (!(*it)->isLeader())
+        continue;
+
+      // Value can be removed if it is dead and all other tied values are also
+      // dead.
+      bool allDead = true;
+      for (auto memberIt = tiedSuccessorInputs.member_begin(**it);
+           memberIt != tiedSuccessorInputs.member_end(); ++memberIt) {
+        if (!memberIt->use_empty()) {
+          allDead = false;
+          break;
+        }
+      }
+      if (!allDead)
+        continue;
+
+      // Group values by block and operation.
+      for (auto memberIt = tiedSuccessorInputs.member_begin(**it);
+           memberIt != tiedSuccessorInputs.member_end(); ++memberIt) {
+        if (auto arg = dyn_cast<BlockArgument>(*memberIt)) {
+          BitVector &vector =
+              lookupOrCreateBitVector(blockArgsToRemove, arg.getOwner(),
+                                      arg.getOwner()->getNumArguments());
+          vector.set(arg.getArgNumber());
+        } else {
+          OpResult result = cast<OpResult>(*memberIt);
+          BitVector &vector =
+              lookupOrCreateBitVector(resultsToRemove, result.getDefiningOp(),
+                                      result.getDefiningOp()->getNumResults());
+          vector.set(result.getResultNumber());
+        }
+        valuesToRemove.push_back(*memberIt);
+      }
+    }
+
+    if (valuesToRemove.empty())
+      return rewriter.notifyMatchFailure(op, "no values to remove");
+
+    // Find operands that must be removed together with the values.
+    RegionBranchInverseSuccessorMapping inputsToOperands;
+    regionBranchOp.getSuccessorInputOperandMapping(inputsToOperands);
+    DenseMap<Operation *, llvm::BitVector> operandsToRemove;
+    for (Value value : valuesToRemove) {
+      for (OpOperand *operand : inputsToOperands[value]) {
+        BitVector &vector =
+            lookupOrCreateBitVector(operandsToRemove, operand->getOwner(),
+                                    operand->getOwner()->getNumOperands());
+        vector.set(operand->getOperandNumber());
+      }
+    }
+
+    // Erase operands.
+    for (auto &pair : operandsToRemove) {
+      Operation *op = pair.first;
+      BitVector &operands = pair.second;
+      rewriter.modifyOpInPlace(op, [&]() { op->eraseOperands(operands); });
+    }
+
+    // Erase block arguments.
+    for (auto &pair : blockArgsToRemove) {
+      Block *block = pair.first;
+      BitVector &blockArg = pair.second;
+      rewriter.modifyOpInPlace(block->getParentOp(),
+                               [&]() { block->eraseArguments(blockArg); });
+    }
+
+    // Erase op results.
+    // TODO: Can we move this to RewriterBase, so we have a uniform API,
+    // similar to eraseArguments?
+    for (auto [op, resultsToErase] : resultsToRemove) {
+      rewriter.setInsertionPoint(op);
+      SmallVector<Type> newResultTypes;
+      for (OpResult result : op->getResults())
+        if (!resultsToErase[result.getResultNumber()])
+          newResultTypes.push_back(result.getType());
+      OperationState state(op->getLoc(), op->getName().getStringRef(),
+                           op->getOperands(), newResultTypes, op->getAttrs());
+      for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
----------------
matthias-springer wrote:

This is no longer needed due to #174152.

https://github.com/llvm/llvm-project/pull/174094
_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to