srishti-pm updated this revision to Diff 427964.
srishti-pm added a comment.

Fixing a comment typo and enhancing the commit summary even further.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D124750/new/

https://reviews.llvm.org/D124750

Files:
  clang/docs/tools/clang-formatted-files.txt
  mlir/include/mlir/Transforms/CommutativityUtils.h
  mlir/lib/Transforms/Utils/CMakeLists.txt
  mlir/lib/Transforms/Utils/CommutativityUtils.cpp
  mlir/test/Transforms/test-commutativity-utils.mlir
  mlir/test/lib/Dialect/Test/TestOps.td
  mlir/test/lib/Transforms/CMakeLists.txt
  mlir/test/lib/Transforms/TestCommutativityUtils.cpp
  mlir/tools/mlir-opt/mlir-opt.cpp

Index: mlir/tools/mlir-opt/mlir-opt.cpp
===================================================================
--- mlir/tools/mlir-opt/mlir-opt.cpp
+++ mlir/tools/mlir-opt/mlir-opt.cpp
@@ -56,6 +56,7 @@
 void registerVectorizerTestPass();
 
 namespace test {
+void registerCommutativityUtils();
 void registerConvertCallOpPass();
 void registerInliner();
 void registerMemRefBoundCheck();
@@ -146,6 +147,7 @@
   registerVectorizerTestPass();
   registerTosaTestQuantUtilAPIPass();
 
+  mlir::test::registerCommutativityUtils();
   mlir::test::registerConvertCallOpPass();
   mlir::test::registerInliner();
   mlir::test::registerMemRefBoundCheck();
Index: mlir/test/lib/Transforms/TestCommutativityUtils.cpp
===================================================================
--- /dev/null
+++ mlir/test/lib/Transforms/TestCommutativityUtils.cpp
@@ -0,0 +1,67 @@
+//===- TestCommutativityUtils.cpp - Pass to test the commutativity utility-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass tests the functionality of the commutativity utility.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/CommutativityUtils.h"
+
+#include "TestDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace test;
+
+namespace {
+
+/// Rewrite pattern rooted at `OpTy` that hands the matched op to
+/// `sortCommutativeOperands`. That call also sorts the ops defining the
+/// matched op's operands.
+///
+/// NOTE(review): the pattern reports success unconditionally, even when the
+/// operands are already sorted. This presumably keeps the greedy driver from
+/// reaching a fixpoint until its iteration limit; the pass below discards the
+/// driver's result. Confirm whether that is intended.
+template <typename OpTy>
+struct SortOperandsPattern : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy commutativeOp,
+                                PatternRewriter &rewriter) const override {
+    sortCommutativeOperands(commutativeOp.getOperation(), rewriter);
+    return success();
+  }
+};
+
+/// Test pass (`-test-commutativity-utils`) that greedily applies the sorting
+/// pattern to every `test.op_large_commutative` and `test.op_commutative` op
+/// in each function.
+struct CommutativityUtils
+    : public PassWrapper<CommutativityUtils, OperationPass<FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CommutativityUtils)
+
+  StringRef getArgument() const final { return "test-commutativity-utils"; }
+  StringRef getDescription() const final {
+    return "Test the functionality of the commutativity utility";
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    patterns.add<SortOperandsPattern<TestLargeCommutativeOp>,
+                 SortOperandsPattern<TestCommutativeOp>>(ctx);
+    // Convergence is deliberately not checked; see the note on
+    // SortOperandsPattern.
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+/// Registers the `test-commutativity-utils` test pass by constructing a
+/// PassRegistration for it.
+void registerCommutativityUtils() {
+  PassRegistration<CommutativityUtils>();
+}
+} // namespace test
+} // namespace mlir
Index: mlir/test/lib/Transforms/CMakeLists.txt
===================================================================
--- mlir/test/lib/Transforms/CMakeLists.txt
+++ mlir/test/lib/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestTransforms
+  TestCommutativityUtils.cpp
   TestConstantFold.cpp
   TestControlFlowSink.cpp
   TestInlining.cpp
Index: mlir/test/lib/Dialect/Test/TestOps.td
===================================================================
--- mlir/test/lib/Dialect/Test/TestOps.td
+++ mlir/test/lib/Dialect/Test/TestOps.td
@@ -1101,11 +1101,21 @@
   let hasFolder = 1;
 }
 
+// Two-operand op without the `Commutative` trait. The commutativity-utility
+// tests use it as a non-commutative producer of operands.
+def TestAddIOp : TEST_Op<"addi"> {
+  let results = (outs I32);
+  let arguments = (ins I32:$op1, I32:$op2);
+}
+
 def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> {
+  // Four I32 operands. The `Commutative` trait is what allows
+  // `sortCommutativeOperands` to reorder them.
   let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4);
   let results = (outs I32);
 }
 
+// Seven-operand op with the `Commutative` trait. It exercises operand sorting
+// on a larger operand list than `TestCommutativeOp`.
+def TestLargeCommutativeOp : TEST_Op<"op_large_commutative", [Commutative]> {
+  let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4, I32:$op5,
+                       I32:$op6, I32:$op7);
+  let results = (outs I32);
+}
+
 def TestCommutative2Op : TEST_Op<"op_commutative2", [Commutative]> {
   let arguments = (ins I32:$op1, I32:$op2);
   let results = (outs I32);
Index: mlir/test/Transforms/test-commutativity-utils.mlir
===================================================================
--- /dev/null
+++ mlir/test/Transforms/test-commutativity-utils.mlir
@@ -0,0 +1,116 @@
+// RUN: mlir-opt %s -test-commutativity-utils | FileCheck %s
+
+// Every operand of the commutative op is defined by a distinct op, so one BFS
+// step decides the order. The non-constant-like producers sort by op name
+// (arith.addi < arith.muli < test.addi), and the constant-like arith.constant
+// goes last.
+// CHECK-LABEL: @test_small_pattern_1
+func @test_small_pattern_1(%arg0 : i32) -> i32 {
+  // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant
+  %0 = arith.constant 45 : i32
+
+  // CHECK-NEXT: %[[TEST_ADD:.*]] = "test.addi"
+  %1 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
+
+  // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi
+  %2 = arith.addi %arg0, %arg0 : i32
+
+  // CHECK-NEXT: %[[ARITH_MUL:.*]] = arith.muli
+  %3 = arith.muli %arg0, %arg0 : i32
+
+  // Operand keys: %0 -> "3arith.constant", %1 -> "1test.addi",
+  // %2 -> "1arith.addi", %3 -> "1arith.muli".
+  // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARITH_ADD]], %[[ARITH_MUL]], %[[TEST_ADD]], %[[ARITH_CONST]])
+  %result = "test.op_commutative"(%0, %1, %2, %3): (i32, i32, i32, i32) -> i32
+
+  // CHECK-NEXT: return %[[RESULT]]
+  return %result : i32
+}
+
+// Covers all three operand categories. Expected order: first an operand
+// defined by a non-constant-like op (arith.addi), then a block argument
+// (%arg0), then operands defined by constant-like ops in name order
+// (arith.constant < test.constant). This assumes `test.constant` carries the
+// ConstantLike trait; confirm in TestOps.td.
+// CHECK-LABEL: @test_small_pattern_2
+// CHECK-SAME: (%[[ARG0:.*]]: i32
+func @test_small_pattern_2(%arg0 : i32) -> i32 {
+  // CHECK-NEXT: %[[TEST_CONST:.*]] = "test.constant"
+  %0 = "test.constant"() {value = 0 : i32} : () -> i32
+
+  // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant
+  %1 = arith.constant 0 : i32
+
+  // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi
+  %2 = arith.addi %arg0, %arg0 : i32
+
+  // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARITH_ADD]], %[[ARG0]], %[[ARITH_CONST]], %[[TEST_CONST]])
+  %result = "test.op_commutative"(%0, %1, %2, %arg0): (i32, i32, i32, i32) -> i32
+
+  // CHECK-NEXT: return %[[RESULT]]
+  return %result : i32
+}
+
+// Deeper operand trees. Several operands of the commutative ops are defined
+// by the same op name (arith.divsi). Their relative order is decided by later
+// levels of the BFS over their ancestors. Operands repeated verbatim (%12,
+// %19) appear next to each other in the sorted result.
+// CHECK-LABEL: @test_large_pattern
+func @test_large_pattern(%arg0 : i32, %arg1 : i32) -> i32 {
+  // CHECK-NEXT: arith.divsi
+  %0 = arith.divsi %arg0, %arg1 : i32
+
+  // CHECK-NEXT: arith.divsi
+  %1 = arith.divsi %0, %arg0 : i32
+
+  // CHECK-NEXT: arith.divsi
+  %2 = arith.divsi %1, %arg1 : i32
+
+  // CHECK-NEXT: arith.addi
+  %3 = arith.addi %1, %arg1 : i32
+
+  // CHECK-NEXT: arith.subi
+  %4 = arith.subi %2, %3 : i32
+
+  // CHECK-NEXT: "test.addi"
+  %5 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
+
+  // CHECK-NEXT: %[[VAL6:.*]] = arith.divsi
+  %6 = arith.divsi %4, %5 : i32
+
+  // CHECK-NEXT: arith.divsi
+  %7 = arith.divsi %1, %arg1 : i32
+
+  // CHECK-NEXT: %[[VAL8:.*]] = arith.muli
+  %8 = arith.muli %1, %arg1 : i32
+
+  // CHECK-NEXT: %[[VAL9:.*]] = arith.subi
+  %9 = arith.subi %7, %8 : i32
+
+  // CHECK-NEXT: "test.addi"
+  %10 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
+
+  // CHECK-NEXT: %[[VAL11:.*]] = arith.divsi
+  %11 = arith.divsi %9, %10 : i32
+
+  // CHECK-NEXT: %[[VAL12:.*]] = arith.divsi
+  %12 = arith.divsi %6, %arg1 : i32
+
+  // CHECK-NEXT: arith.subi
+  %13 = arith.subi %arg1, %arg0 : i32
+
+  // Producers sort by name: arith.divsi (%12) < arith.muli (%8) < arith.subi
+  // (%9).
+  // CHECK-NEXT: "test.op_commutative"(%[[VAL12]], %[[VAL12]], %[[VAL8]], %[[VAL9]])
+  %14 = "test.op_commutative"(%12, %9, %12, %8): (i32, i32, i32, i32) -> i32
+
+  // CHECK-NEXT: %[[VAL15:.*]] = arith.divsi
+  %15 = arith.divsi %13, %14 : i32
+
+  // CHECK-NEXT: %[[VAL16:.*]] = arith.addi
+  %16 = arith.addi %2, %15 : i32
+
+  // CHECK-NEXT: arith.subi
+  %17 = arith.subi %16, %arg1 : i32
+
+  // CHECK-NEXT: "test.addi"
+  %18 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
+
+  // CHECK-NEXT: %[[VAL19:.*]] = arith.divsi
+  %19 = arith.divsi %17, %18 : i32
+
+  // CHECK-NEXT: "test.addi"
+  %20 = "test.addi"(%arg0, %16): (i32, i32) -> i32
+
+  // CHECK-NEXT: %[[VAL21:.*]] = arith.divsi
+  %21 = arith.divsi %17, %20 : i32
+
+  // %16 is the only operand defined by arith.addi, so it sorts before all the
+  // arith.divsi-defined operands. Those are ordered by their deeper ancestors.
+  // CHECK-NEXT: %[[RESULT:.*]] = "test.op_large_commutative"(%[[VAL16]], %[[VAL21]], %[[VAL19]], %[[VAL19]], %[[VAL6]], %[[VAL11]], %[[VAL15]])
+  %result = "test.op_large_commutative"(%16, %6, %11, %15, %19, %19, %21): (i32, i32, i32, i32, i32, i32, i32) -> i32
+
+  // CHECK-NEXT: return %[[RESULT]]
+  return %result : i32
+}
Index: mlir/lib/Transforms/Utils/CommutativityUtils.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Transforms/Utils/CommutativityUtils.cpp
@@ -0,0 +1,399 @@
+//===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a utility that is intended to be used inside a pass or
+// an individual pattern to simplify the matching of commutative operations.
+// Note that this utility can also be used inside PDL patterns in conjunction
+// with the `pdl.apply_native_rewrite` op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/CommutativityUtils.h"
+
+#include "mlir/IR/PatternMatch.h"
+#include <queue>
+
+#define DEBUG_TYPE "commutativity-utils"
+
+using namespace mlir;
+
+namespace {
+/// Stores the BFS traversal information of an operand.
+///
+/// Kept in an anonymous namespace so this file-local helper type has internal
+/// linkage and cannot collide with a same-named type in another translation
+/// unit.
+struct OperandBFS {
+  /// Queue of the operand's ancestors that the BFS has yet to visit, at a
+  /// particular point in time. A null entry stands for a block argument.
+  std::queue<Operation *> ancestorQueue;
+
+  /// Ancestor ops already pushed into `ancestorQueue`, at a particular point
+  /// in time. Block arguments (null entries) are never recorded here.
+  DenseSet<Operation *> visitedAncestors;
+
+  /// The key corresponding to the BFS traversal of the operand at a
+  /// particular point in time.
+  /// Some examples:
+  /// 1. If the BFS has seen `arith.addi`,
+  ///    then,
+  ///    the key will store the string:
+  ///      "1arith.addi".
+  /// 2. If the BFS has seen `arg5`,
+  ///    then,
+  ///    the key will store the string:
+  ///       "2".
+  /// 3. If the BFS has seen `arith.constant`,
+  ///    then,
+  ///    the key will store the string:
+  ///       "3arith.constant".
+  /// 4. If the BFS has seen `arith.addi`, `test.constant`, `scf.if`, `tf.Add`,
+  ///    `arith.constant`, and `arg5` (in BFS order),
+  ///    then,
+  ///    the key will store the string:
+  ///       "1arith.addi3test.constant1scf.if1tf.Add3arith.constant2".
+  ///
+  /// With this definition, sorting keys in ascending order places operands in
+  /// three groups:
+  ///   1. operands defined by non-constant-like ops,
+  ///   2. block arguments,
+  ///   3. operands defined by constant-like ops.
+  /// Within groups (1) and (3), operands are further ordered by byte-wise
+  /// lexicographic comparison of the op names (dialect prefix included).
+  ///
+  /// As an example of the comparison of keys, consider the following
+  /// commutative op (foo.op):
+  ///   e = foo.div f, g
+  ///   c = foo.constant
+  ///   b = foo.add e, d
+  ///   a = foo.add c, d
+  ///   s = foo.op a, b,
+  /// then,
+  /// the key associated with operand `a` will be "1foo.add3foo.constant", and,
+  /// the key associated with operand `b` will be "1foo.add1foo.div",
+  /// and thus,
+  /// key of `a` > key of `b`,
+  ///
+  /// which means that a "sorted" foo.op would look like:
+  ///   s = foo.op b, a (instead of a, b).
+  ///
+  /// NOTE(review): op names are concatenated without a separator. When one
+  /// name is a prefix of another (e.g. `foo.ad` and `foo.ad0`), the comparison
+  /// is decided by the category digit that follows the shorter name, not by
+  /// the names alone. A structured key (e.g. a list of (category, name) pairs)
+  /// would avoid this.
+  std::string key;
+
+  /// True iff the operand has already been assigned a sorted position.
+  bool isAssignedSortedPosition = false;
+
+  /// Push an ancestor into the operand's BFS information structure. The
+  /// ancestor is always pushed into the queue. It is inserted into the
+  /// "visited ancestors" set only if it is not null, i.e., only if it is an op
+  /// rather than a block argument.
+  void pushAncestor(Operation *ancestor) {
+    ancestorQueue.push(ancestor);
+    if (ancestor)
+      visitedAncestors.insert(ancestor);
+  }
+
+  /// Pop the ancestor from the front of the queue.
+  void popAncestor() {
+    assert(!ancestorQueue.empty() &&
+           "to pop the ancestor from the front of the queue, the ancestor "
+           "queue should be non-empty");
+    ancestorQueue.pop();
+  }
+
+  /// Return the ancestor at the front of the queue (null for a block
+  /// argument) without removing it.
+  Operation *frontAncestor() const {
+    assert(!ancestorQueue.empty() &&
+           "to access the ancestor at the front of the queue, the ancestor "
+           "queue should be non-empty");
+    return ancestorQueue.front();
+  }
+};
+} // namespace
+
+/// Returns true iff at least one unassigned operand exists. An unassigned
+/// operand is one that has not yet been assigned a sorted position.
+///
+/// `bfsOfOperands` is taken as a non-owning ArrayRef so the caller's vector is
+/// not copied; this check runs once per BFS step.
+static bool
+hasAtLeastOneUnassignedOperand(ArrayRef<OperandBFS *> bfsOfOperands) {
+  return llvm::any_of(bfsOfOperands, [](const OperandBFS *bfsOfOperand) {
+    return !bfsOfOperand->isAssignedSortedPosition;
+  });
+}
+
+/// Goes through all the unassigned operands of `bfsOfOperands` and:
+/// 1. Stores the indices of the ones with the smallest key in
+/// `smallestKeyIndices`,
+/// 2. Stores the indices of the ones with the largest key in
+/// `largestKeyIndices`,
+/// 3. Sets `hasASingleOperandWithSmallestKey` as true if exactly one of them
+/// has the smallest key (and as false otherwise), AND,
+/// 4. Sets `hasASingleOperandWithLargestKey` as true if exactly one of them has
+/// the largest key (and as false otherwise).
+///
+/// Indices are inserted into the two sets; existing contents are not cleared.
+/// If no operand is unassigned, both sets are left untouched and both flags
+/// are set to false.
+static void getIndicesOfUnassignedOperandsWithSmallestAndLargestKeys(
+    ArrayRef<OperandBFS *> bfsOfOperands,
+    DenseSet<unsigned> &smallestKeyIndices,
+    DenseSet<unsigned> &largestKeyIndices,
+    bool &hasASingleOperandWithSmallestKey,
+    bool &hasASingleOperandWithLargestKey) {
+  // Find the smallest and largest keys among the unassigned operands. The keys
+  // are referenced rather than copied; none of them is modified in here.
+  const std::string *smallestKey = nullptr;
+  const std::string *largestKey = nullptr;
+  for (const OperandBFS *bfsOfOperand : bfsOfOperands) {
+    if (bfsOfOperand->isAssignedSortedPosition)
+      continue;
+    const std::string &currentKey = bfsOfOperand->key;
+    if (!smallestKey || currentKey < *smallestKey)
+      smallestKey = &currentKey;
+    if (!largestKey || *largestKey < currentKey)
+      largestKey = &currentKey;
+  }
+
+  hasASingleOperandWithSmallestKey = false;
+  hasASingleOperandWithLargestKey = false;
+  // No unassigned operand: nothing more to report.
+  if (!smallestKey)
+    return;
+
+  // Collect every unassigned operand holding the smallest or largest key. An
+  // operand can land in both sets when all unassigned keys are equal.
+  unsigned numWithSmallestKey = 0;
+  unsigned numWithLargestKey = 0;
+  for (auto indexedBfsOfOperand : llvm::enumerate(bfsOfOperands)) {
+    const OperandBFS *bfsOfOperand = indexedBfsOfOperand.value();
+    if (bfsOfOperand->isAssignedSortedPosition)
+      continue;
+
+    unsigned index = indexedBfsOfOperand.index();
+    const std::string &currentKey = bfsOfOperand->key;
+    if (currentKey == *smallestKey) {
+      smallestKeyIndices.insert(index);
+      ++numWithSmallestKey;
+    }
+    if (currentKey == *largestKey) {
+      largestKeyIndices.insert(index);
+      ++numWithLargestKey;
+    }
+  }
+  hasASingleOperandWithSmallestKey = numWithSmallestKey == 1;
+  hasASingleOperandWithLargestKey = numWithLargestKey == 1;
+}
+
+/// Update the key associated with each unassigned operand in `bfsOfOperands`.
+/// Updating a key means bringing it up to date with the part of its operand's
+/// BFS traversal that has happened so far. A key directly reflects the BFS, so
+/// it must be updated after every change to the BFS queue as the traversal
+/// proceeds.
+///
+/// For each unassigned operand with a non-empty queue, one token describing
+/// the front ancestor is appended to the key. The ancestor itself is not
+/// popped.
+///   * "2" for a block argument (null ancestor). This positions such operands
+///     between those defined by non-constant-like and constant-like ops.
+///   * "3" + op name for a constant-like op, so these operands come last and
+///     are ordered by name among themselves.
+///   * "1" + op name for any other op, so these operands come first and are
+///     ordered by name among themselves.
+/// Keys are extended in place rather than rebuilt, so each update costs only
+/// the length of the appended token.
+static void updateKeys(ArrayRef<OperandBFS *> bfsOfOperands) {
+  for (OperandBFS *bfsOfOperand : bfsOfOperands) {
+    if (bfsOfOperand->isAssignedSortedPosition ||
+        bfsOfOperand->ancestorQueue.empty())
+      continue;
+
+    std::string &key = bfsOfOperand->key;
+    Operation *frontAncestor = bfsOfOperand->frontAncestor();
+    if (!frontAncestor) {
+      key += '2';
+      continue;
+    }
+
+    key += (frontAncestor->hasTrait<OpTrait::ConstantLike>() ? '3' : '1');
+    StringRef name = frontAncestor->getName().getStringRef();
+    key.append(name.data(), name.size());
+  }
+}
+
+/// Rewrite `op`, i.e., rearrange its operands in a "sorted" order.
+/// The operands of an op are considered to be "sorted" iff:
+///   1. The op is not commutative, OR,
+///   2. It is commutative and its operands are in ascending order of the "keys"
+///      associated with them.
+///
+/// `operandDefOps` lists, in operand order, the op defining each operand of
+/// `op`. A block-argument operand has a null entry. There must be exactly one
+/// entry per operand.
+///
+/// The BFS traversals of all operands advance in lock-step. After every step
+/// the keys are compared, and an operand is given its final position as soon
+/// as that position is certain.
+///   * An unassigned operand with the smallest key takes the smallest free
+///     position if it is the only operand with that key, or if its BFS is
+///     complete. A complete key no longer grows, and any key that extends it is
+///     not smaller.
+///   * Operands with the largest key take the largest free positions if one
+///     operand holds that key alone, or if every operand holding it has a
+///     complete BFS. An operand whose BFS is still running will extend its key
+///     beyond the complete ones, so it must not be overtaken by them.
+/// Operands with equal final keys keep their original relative order, so
+/// re-sorting an already sorted op does not permute them.
+static void
+rewriteCommutativeOperands(Operation *op, ArrayRef<Operation *> operandDefOps,
+                           PatternRewriter &rewriter) {
+  // If `op` is not commutative, do nothing.
+  if (!op->hasTrait<OpTrait::IsCommutative>())
+    return;
+
+  unsigned numOperands = op->getNumOperands();
+  assert(operandDefOps.size() == numOperands &&
+         "expected one entry in `operandDefOps` per operand of `op`");
+  // Zero or one operand is trivially sorted.
+  if (numOperands < 2)
+    return;
+
+  // `bfsStorage` owns the BFS traversal information of every operand: its
+  // ancestor queue, visited ancestors, key, and whether it has been assigned
+  // a sorted position. The storage is released automatically on return. It is
+  // never resized below, so the non-owning pointers in `bfsOfOperands` (the
+  // form the helper functions take) stay valid. Each operand's BFS starts
+  // from the op defining it.
+  SmallVector<OperandBFS, 2> bfsStorage(numOperands);
+  SmallVector<OperandBFS *, 2> bfsOfOperands;
+  bfsOfOperands.reserve(numOperands);
+  for (unsigned index = 0; index < numOperands; ++index) {
+    bfsStorage[index].pushAncestor(operandDefOps[index]);
+    bfsOfOperands.push_back(&bfsStorage[index]);
+  }
+
+  // Every operand ends up in a position in [0, N - 1]. Positions are handed
+  // out from both ends towards the middle.
+  unsigned smallestUnassignedPosition = 0;
+  unsigned largestUnassignedPosition = numOperands - 1;
+  // Entries stay null until the corresponding position is assigned.
+  SmallVector<Value, 2> sortedOperands(numOperands);
+
+  while (hasAtLeastOneUnassignedOperand(bfsOfOperands)) {
+    // Update the keys corresponding to all unassigned operands.
+    updateKeys(bfsOfOperands);
+
+    DenseSet<unsigned> smallestKeyIndices;
+    DenseSet<unsigned> largestKeyIndices;
+    bool hasASingleOperandWithSmallestKey;
+    bool hasASingleOperandWithLargestKey;
+    getIndicesOfUnassignedOperandsWithSmallestAndLargestKeys(
+        bfsOfOperands, smallestKeyIndices, largestKeyIndices,
+        hasASingleOperandWithSmallestKey, hasASingleOperandWithLargestKey);
+
+    // Smallest key. Walk in ascending index order so that tied operands fill
+    // ascending positions in their original relative order.
+    for (unsigned index = 0; index < numOperands; ++index) {
+      OperandBFS &bfs = bfsStorage[index];
+      if (bfs.isAssignedSortedPosition || !smallestKeyIndices.contains(index))
+        continue;
+      if (!hasASingleOperandWithSmallestKey && !bfs.ancestorQueue.empty())
+        continue;
+      bfs.isAssignedSortedPosition = true;
+      sortedOperands[smallestUnassignedPosition++] = op->getOperand(index);
+    }
+
+    // Largest key. This is safe only once no operand holding it can still
+    // grow its key. Walk in descending index order so that tied operands fill
+    // descending positions while keeping their original relative order.
+    // Operands already placed above (possible when all unassigned keys are
+    // equal) are skipped.
+    bool largestKeyIsFinal =
+        hasASingleOperandWithLargestKey ||
+        llvm::all_of(largestKeyIndices, [&](unsigned index) {
+          return bfsStorage[index].ancestorQueue.empty();
+        });
+    if (largestKeyIsFinal) {
+      for (unsigned index = numOperands; index-- > 0;) {
+        OperandBFS &bfs = bfsStorage[index];
+        if (bfs.isAssignedSortedPosition || !largestKeyIndices.contains(index))
+          continue;
+        bfs.isAssignedSortedPosition = true;
+        sortedOperands[largestUnassignedPosition--] = op->getOperand(index);
+      }
+    }
+
+    // Advance the BFS of every still-unassigned operand by one step. Pop the
+    // front ancestor, then push its operands' defining ops that have not been
+    // visited yet. Block arguments are pushed as null every time.
+    for (OperandBFS &bfs : bfsStorage) {
+      if (bfs.isAssignedSortedPosition || bfs.ancestorQueue.empty())
+        continue;
+      Operation *frontAncestor = bfs.frontAncestor();
+      bfs.popAncestor();
+      if (!frontAncestor)
+        continue;
+      for (Value operand : frontAncestor->getOperands()) {
+        Operation *thisOperandDefOp = operand.getDefiningOp();
+        if (!thisOperandDefOp ||
+            !bfs.visitedAncestors.contains(thisOperandDefOp))
+          bfs.pushAncestor(thisOperandDefOp);
+      }
+    }
+  }
+
+  // Leave `op` (and the rewriter) untouched if it is already sorted.
+  if (llvm::equal(op->getOperands(), sortedOperands))
+    return;
+  rewriter.updateRootInPlace(op, [&] { op->setOperands(sortedOperands); });
+}
+
+/// Sorts `op`.
+/// "Sorting" `op` means first "sorting" every op in its backward slice (the
+/// ops reached by transitively following operand definitions), then
+/// rearranging the operands of `op` itself in the "sorted" order. The defining
+/// ops are sorted first because the keys of `op`'s operands come from a BFS
+/// over those ops' operands, in operand order. Sorting them first makes the
+/// rearrangement of `op` deterministic.
+void mlir::sortCommutativeOperands(Operation *op, PatternRewriter &rewriter) {
+  assert(op && "the input argument `op` must not be null");
+
+  // Iterative post-order DFS over the use-def graph rooted at `op`. Each op is
+  // sorted exactly once, after all ops defining its operands have been
+  // sorted. Without the `visited` set, shared subexpressions would be
+  // re-sorted once per path to them (exponential work on DAGs), and use-def
+  // cycles, which graph regions permit, would never terminate. The explicit
+  // stack avoids overflowing the native stack on long use-def chains.
+  struct StackEntry {
+    Operation *op;
+    unsigned nextOperand;
+  };
+  DenseSet<Operation *> visited;
+  SmallVector<StackEntry, 8> stack;
+  visited.insert(op);
+  stack.push_back({op, 0});
+
+  while (!stack.empty()) {
+    StackEntry &top = stack.back();
+    if (top.nextOperand < top.op->getNumOperands()) {
+      Operation *operandDefOp =
+          top.op->getOperand(top.nextOperand++).getDefiningOp();
+      // `top` may dangle after the push below; it is not used afterwards.
+      if (operandDefOp && visited.insert(operandDefOp).second)
+        stack.push_back({operandDefOp, 0});
+      continue;
+    }
+
+    // All ops defining `current`'s operands are sorted; now sort `current`.
+    Operation *current = top.op;
+    stack.pop_back();
+    SmallVector<Operation *, 2> operandDefOps;
+    for (Value operand : current->getOperands())
+      operandDefOps.push_back(operand.getDefiningOp());
+    rewriteCommutativeOperands(current, operandDefOps, rewriter);
+  }
+}
Index: mlir/lib/Transforms/Utils/CMakeLists.txt
===================================================================
--- mlir/lib/Transforms/Utils/CMakeLists.txt
+++ mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_library(MLIRTransformUtils
+  CommutativityUtils.cpp
   ControlFlowSinkUtils.cpp
   DialectConversion.cpp
   FoldUtils.cpp
Index: mlir/include/mlir/Transforms/CommutativityUtils.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Transforms/CommutativityUtils.h
@@ -0,0 +1,28 @@
+//===- CommutativityUtils.h - Commutativity utilities -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file declares a utility that is intended to be used inside a pass
+// or an individual pattern to simplify the matching of commutative operations.
+// Note that this utility can also be used inside PDL patterns in conjunction
+// with the `pdl.apply_native_rewrite` op.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H
+#define MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H
+
+namespace mlir {
+
+class Operation;
+class PatternRewriter;
+
+/// Sorts the operands of `op` into a canonical order if `op` is commutative.
+/// Before that, it does the same for every op in `op`'s backward slice (the
+/// ops reached by transitively following operand definitions). Ops that are
+/// not commutative keep their operand order.
+///
+/// Operands are ordered by a key built from a breadth-first walk over their
+/// defining ops. Operands defined by non-constant-like ops come first, then
+/// block arguments, then operands defined by constant-like ops. Within the
+/// op-defined groups, op names are compared lexicographically.
+///
+/// `op` must not be null. All modifications are made in place through
+/// `rewriter`.
+void sortCommutativeOperands(Operation *op, PatternRewriter &rewriter);
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H
Index: clang/docs/tools/clang-formatted-files.txt
===================================================================
--- clang/docs/tools/clang-formatted-files.txt
+++ clang/docs/tools/clang-formatted-files.txt
@@ -7888,6 +7888,7 @@
 mlir/include/mlir/Tools/PDLL/ODS/Dialect.h
 mlir/include/mlir/Tools/PDLL/ODS/Operation.h
 mlir/include/mlir/Tools/PDLL/Parser/Parser.h
+mlir/include/mlir/Transforms/CommutativityUtils.h
 mlir/include/mlir/Transforms/ControlFlowSinkUtils.h
 mlir/include/mlir/Transforms/DialectConversion.h
 mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -8448,6 +8449,7 @@
 mlir/lib/Transforms/StripDebugInfo.cpp
 mlir/lib/Transforms/SymbolDCE.cpp
 mlir/lib/Transforms/SymbolPrivatize.cpp
+mlir/lib/Transforms/Utils/CommutativityUtils.cpp
 mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
 mlir/lib/Transforms/Utils/DialectConversion.cpp
 mlir/lib/Transforms/Utils/FoldUtils.cpp
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to