[Mlir-commits] [mlir] b508c56 - [MLIR] Add a utility to sort the operands of commutative ops

Jeff Niu llvmlistbot at llvm.org
Sat Jul 30 16:25:23 PDT 2022


Author: srishti-cb
Date: 2022-07-30T19:25:18-04:00
New Revision: b508c5649f5e21e17e9f5633236ec61c551803af

URL: https://github.com/llvm/llvm-project/commit/b508c5649f5e21e17e9f5633236ec61c551803af
DIFF: https://github.com/llvm/llvm-project/commit/b508c5649f5e21e17e9f5633236ec61c551803af.diff

LOG: [MLIR] Add a utility to sort the operands of commutative ops

Added a commutativity utility pattern and a function to populate it. The pattern sorts the operands of an op in ascending order of the "key" associated with each operand iff the op is commutative. This sorting is stable.

The function is intended to be used inside passes to simplify the matching of commutative operations. After the above-mentioned pattern is applied, the operands of every commutative op occur in a deterministic order, so matching large DAGs becomes much simpler: users need to write far fewer checks in their pattern-matching functions.

The "key" associated with an operand is the list of the "AncestorKeys" associated with the ancestors of this operand, in a breadth-first order.

Each operand of an op is produced, directly or transitively, by a set of ops and block arguments. Each of these ops and block arguments is called an "ancestor" of this operand.

Now, the "AncestorKey" associated with:
1. A block argument is `{type: BLOCK_ARGUMENT, opName: ""}`.
2. A non-constant-like op, for example, `arith.addi`, is `{type: NON_CONSTANT_OP, opName: "arith.addi"}`.
3. A constant-like op, for example, `arith.constant`, is `{type: CONSTANT_OP, opName: "arith.constant"}`.
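
To make the three cases concrete, here is a minimal standalone C++ sketch of how an AncestorKey could be derived from an ancestor. It assumes a hypothetical `ToyOp` type (a name plus a flag standing in for the `ConstantLike` trait) in place of `mlir::Operation`, uses a null pointer to model a block argument, and uses `std::string` instead of `StringRef`. It illustrates the rules above; it is not the code added by this commit.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for mlir::Operation: an op name and whether the op
// carries the ConstantLike trait. A null ToyOp pointer models a block argument.
struct ToyOp {
  std::string name;
  bool constantLike;
};

enum AncestorType { BLOCK_ARGUMENT, NON_CONSTANT_OP, CONSTANT_OP };

struct AncestorKey {
  AncestorType type;
  std::string opName; // Stays empty ("") for block arguments.

  explicit AncestorKey(const ToyOp *op) {
    if (!op) {
      type = BLOCK_ARGUMENT;
      return;
    }
    type = op->constantLike ? CONSTANT_OP : NON_CONSTANT_OP;
    opName = op->name;
  }
};
```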

So, if an operand, say `A`, was produced as follows:

```
`<block argument>`  `<block argument>`
             \          /
              \        /
              `arith.subi`           `arith.constant`
                         \            /
                         `arith.addi`
                                |
                           returns `A`
```

Then, the operations and block arguments in the backward slice of `A`, in breadth-first order, are:
`arith.addi`, `arith.subi`, `arith.constant`, `<block argument>`, and `<block argument>`.

Thus, the "key" associated with operand `A` is:
```
{
 {type: NON_CONSTANT_OP, opName: "arith.addi"},
 {type: NON_CONSTANT_OP, opName: "arith.subi"},
 {type: CONSTANT_OP, opName: "arith.constant"},
 {type: BLOCK_ARGUMENT, opName: ""},
 {type: BLOCK_ARGUMENT, opName: ""}
}
```

Now, if "keyA" is the key associated with operand `A` and "keyB" is the key associated with operand `B`, then "keyA" < "keyB" iff:
1. In the first unequal pair of corresponding AncestorKeys, the AncestorKey of operand `A` is smaller, or,
2. All pairs of corresponding AncestorKeys are equal and operand `A`'s "key" is shorter, i.e., it is a proper prefix of operand `B`'s "key".

AncestorKeys of type `BLOCK_ARGUMENT` are considered the smallest, those of type `CONSTANT_OP`, the largest, and `NON_CONSTANT_OP` types come in between. Within the types `NON_CONSTANT_OP` and `CONSTANT_OP`, the smaller ones are the ones with smaller op names (lexicographically).
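
Taken together, these rules are a plain lexicographic comparison. Assuming an AncestorKey is represented by the hypothetical struct below (with `std::string` in place of `StringRef`), the comparison can be sketched with `std::tie` and `std::lexicographical_compare`; `keyLess` is an illustrative name, not part of the commit. The enum's declaration order provides `BLOCK_ARGUMENT` < `NON_CONSTANT_OP` < `CONSTANT_OP`.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <tuple>
#include <vector>

// Declaration order yields BLOCK_ARGUMENT < NON_CONSTANT_OP < CONSTANT_OP.
enum AncestorType { BLOCK_ARGUMENT, NON_CONSTANT_OP, CONSTANT_OP };

struct AncestorKey {
  AncestorType type;
  std::string opName;
  // Compare by type first, then by op name lexicographically.
  bool operator<(const AncestorKey &other) const {
    return std::tie(type, opName) < std::tie(other.type, other.opName);
  }
};

using Key = std::vector<AncestorKey>;

// keyA < keyB iff the first unequal pair has the smaller AncestorKey in keyA,
// or all pairs are equal and keyA is shorter (a proper prefix of keyB).
bool keyLess(const Key &keyA, const Key &keyB) {
  return std::lexicographical_compare(keyA.begin(), keyA.end(), keyB.begin(),
                                      keyB.end());
}
```

`std::lexicographical_compare` already treats a proper prefix as the smaller sequence, which covers rule 2 without extra code.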

---

Some examples of such a sorting:

Assume that the sorting is being applied to `foo.commutative`, which is a commutative op.

Example 1:

> %1 = foo.const 0
> %2 = foo.mul <block argument>, <block argument>
> %3 = foo.commutative %1, %2

Here,
1. The key associated with %1 is:
```
    {
     {CONSTANT_OP, "foo.const"}
    }
```
2. The key associated with %2 is:
```
    {
     {NON_CONSTANT_OP, "foo.mul"},
     {BLOCK_ARGUMENT, ""},
     {BLOCK_ARGUMENT, ""}
    }
```

The key of %2 < the key of %1, because their first AncestorKeys differ and `NON_CONSTANT_OP` < `CONSTANT_OP`.
Thus, the sorted `foo.commutative` is:
> %3 = foo.commutative %2, %1
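
The final step, a stable sort by key, can be sketched for Example 1 with `std::stable_sort`. The `Operand` struct and the `sortOperands` helper are hypothetical, and the keys are written out by hand; the pattern in this commit instead computes keys lazily, advancing each operand's BFS only as far as a comparison requires.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <tuple>
#include <vector>

enum AncestorType { BLOCK_ARGUMENT, NON_CONSTANT_OP, CONSTANT_OP };

struct AncestorKey {
  AncestorType type;
  std::string opName;
  bool operator<(const AncestorKey &other) const {
    return std::tie(type, opName) < std::tie(other.type, other.opName);
  }
};

// Hypothetical operand record: an SSA name plus its precomputed key.
struct Operand {
  std::string name;
  std::vector<AncestorKey> key;
};

// Stable-sorts operands in ascending key order and returns their names.
std::vector<std::string> sortOperands(std::vector<Operand> operands) {
  std::stable_sort(operands.begin(), operands.end(),
                   [](const Operand &a, const Operand &b) {
                     return std::lexicographical_compare(
                         a.key.begin(), a.key.end(), b.key.begin(),
                         b.key.end());
                   });
  std::vector<std::string> names;
  for (const Operand &operand : operands)
    names.push_back(operand.name);
  return names;
}
```

Because the sort is stable, operands whose keys compare equal (for example, two block arguments) keep their original relative order.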

Example 2:

> %1 = foo.const 0
> %2 = foo.mul <block argument>, <block argument>
> %3 = foo.mul %2, %1
> %4 = foo.add %2, %1
> %5 = foo.commutative %1, %2, %3, %4

Here,
1. The key associated with %1 is:
```
    {
     {CONSTANT_OP, "foo.const"}
    }
```
2. The key associated with %2 is:
```
    {
     {NON_CONSTANT_OP, "foo.mul"},
     {BLOCK_ARGUMENT, ""},
     {BLOCK_ARGUMENT, ""}
    }
```
3. The key associated with %3 is:
```
    {
     {NON_CONSTANT_OP, "foo.mul"},
     {NON_CONSTANT_OP, "foo.mul"},
     {CONSTANT_OP, "foo.const"},
     {BLOCK_ARGUMENT, ""},
     {BLOCK_ARGUMENT, ""}
    }
```
4. The key associated with %4 is:
```
    {
     {NON_CONSTANT_OP, "foo.add"},
     {NON_CONSTANT_OP, "foo.mul"},
     {CONSTANT_OP, "foo.const"},
     {BLOCK_ARGUMENT, ""},
     {BLOCK_ARGUMENT, ""}
    }
```

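Here, %4 comes first because "foo.add" < "foo.mul". %2 and %3 share their first AncestorKey, but the second AncestorKey of %2 is `BLOCK_ARGUMENT`, which is smaller than the `NON_CONSTANT_OP` of %3, so %2 precedes %3. %1 is the only constant and comes last.

Thus, the sorted `foo.commutative` is:
> %5 = foo.commutative %4, %2, %3, %1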

Signed-off-by: Srishti Srivastava <srishti.srivastava at polymagelabs.com>

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D124750

Added: 
    mlir/include/mlir/Transforms/CommutativityUtils.h
    mlir/lib/Transforms/Utils/CommutativityUtils.cpp
    mlir/test/Transforms/test-commutativity-utils.mlir
    mlir/test/lib/Transforms/TestCommutativityUtils.cpp

Modified: 
    mlir/lib/Transforms/Utils/CMakeLists.txt
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Transforms/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/CommutativityUtils.h b/mlir/include/mlir/Transforms/CommutativityUtils.h
new file mode 100644
index 0000000000000..8306926160129
--- /dev/null
+++ b/mlir/include/mlir/Transforms/CommutativityUtils.h
@@ -0,0 +1,27 @@
+//===- CommutativityUtils.h - Commutativity utilities -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file declares a function to populate the commutativity utility
+// pattern. This function is intended to be used inside passes to simplify the
+// matching of commutative operations by fixing the order of their operands.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H
+#define MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+
+/// Populates the commutativity utility patterns.
+void populateCommutativityUtilsPatterns(RewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H

diff  --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index 755e3196837d2..57307e665940a 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_library(MLIRTransformUtils
+  CommutativityUtils.cpp
   ControlFlowSinkUtils.cpp
   DialectConversion.cpp
   FoldUtils.cpp

diff  --git a/mlir/lib/Transforms/Utils/CommutativityUtils.cpp b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp
new file mode 100644
index 0000000000000..57cc3a2958256
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp
@@ -0,0 +1,317 @@
+//===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a commutativity utility pattern and a function to
+// populate this pattern. The function is intended to be used inside passes to
+// simplify the matching of commutative operations by fixing the order of their
+// operands.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/CommutativityUtils.h"
+
+#include <queue>
+
+using namespace mlir;
+
+/// The possible "types" of ancestors. Here, an ancestor is an op or a block
+/// argument present in the backward slice of a value.
+enum AncestorType {
+  /// Pertains to a block argument.
+  BLOCK_ARGUMENT,
+
+  /// Pertains to a non-constant-like op.
+  NON_CONSTANT_OP,
+
+  /// Pertains to a constant-like op.
+  CONSTANT_OP
+};
+
+/// Stores the "key" associated with an ancestor.
+struct AncestorKey {
+  /// Holds `BLOCK_ARGUMENT`, `NON_CONSTANT_OP`, or `CONSTANT_OP`, depending on
+  /// the ancestor.
+  AncestorType type;
+
+  /// Holds the op name of the ancestor if its `type` is `NON_CONSTANT_OP` or
+  /// `CONSTANT_OP`. Else, holds "".
+  StringRef opName;
+
+  /// Constructor for `AncestorKey`.
+  AncestorKey(Operation *op) {
+    if (!op) {
+      type = BLOCK_ARGUMENT;
+    } else {
+      type =
+          op->hasTrait<OpTrait::ConstantLike>() ? CONSTANT_OP : NON_CONSTANT_OP;
+      opName = op->getName().getStringRef();
+    }
+  }
+
+  /// Overloaded operator `<` for `AncestorKey`.
+  ///
+  /// AncestorKeys of type `BLOCK_ARGUMENT` are considered the smallest, those
+  /// of type `CONSTANT_OP`, the largest, and `NON_CONSTANT_OP` types come in
+  /// between. Within the types `NON_CONSTANT_OP` and `CONSTANT_OP`, the smaller
+  /// ones are the ones with smaller op names (lexicographically).
+  ///
+  /// TODO: Include other information like attributes, value type, etc., to
+  /// enhance this comparison. For example, currently this comparison doesn't
+  /// differentiate between `cmpi sle` and `cmpi sgt` or `addi (in i32)` and
+  /// `addi (in i64)`. Such an enhancement should only be done if the need
+  /// arises.
+  bool operator<(const AncestorKey &key) const {
+    return std::tie(type, opName) < std::tie(key.type, key.opName);
+  }
+};
+
+/// Stores a commutative operand along with its BFS traversal information.
+struct CommutativeOperand {
+  /// Stores the operand.
+  Value operand;
+
+  /// Stores the queue of ancestors of the operand's BFS traversal at a
+  /// particular point in time.
+  std::queue<Operation *> ancestorQueue;
+
+  /// Stores the list of ancestors that have been visited by the BFS traversal
+  /// at a particular point in time.
+  DenseSet<Operation *> visitedAncestors;
+
+  /// Stores the operand's "key". This "key" is defined as a list of the
+  /// "AncestorKeys" associated with the ancestors of this operand, in a
+  /// breadth-first order.
+  ///
+  /// So, if an operand, say `A`, was produced as follows:
+  ///
+  /// `<block argument>`  `<block argument>`
+  ///             \          /
+  ///              \        /
+  ///             `arith.subi`           `arith.constant`
+  ///                       \            /
+  ///                        `arith.addi`
+  ///                              |
+  ///                         returns `A`
+  ///
+  /// Then, the ancestors of `A`, in breadth-first order, are:
+  /// `arith.addi`, `arith.subi`, `arith.constant`, `<block argument>`, and
+  /// `<block argument>`.
+  ///
+  /// Thus, the "key" associated with operand `A` is:
+  /// {
+  ///  {type: `NON_CONSTANT_OP`, opName: "arith.addi"},
+  ///  {type: `NON_CONSTANT_OP`, opName: "arith.subi"},
+  ///  {type: `CONSTANT_OP`, opName: "arith.constant"},
+  ///  {type: `BLOCK_ARGUMENT`, opName: ""},
+  ///  {type: `BLOCK_ARGUMENT`, opName: ""}
+  /// }
+  SmallVector<AncestorKey, 4> key;
+
+  /// Push an ancestor into the operand's BFS information structure. This
+  /// entails it being pushed into the queue (always) and inserted into the
+  /// "visited ancestors" list (iff it is an op rather than a block argument).
+  void pushAncestor(Operation *op) {
+    ancestorQueue.push(op);
+    if (op)
+      visitedAncestors.insert(op);
+    return;
+  }
+
+  /// Refresh the key.
+  ///
+  /// Refreshing a key entails making it up-to-date with the operand's BFS
+  /// traversal that has happened up to that point in time, i.e., appending the
+  /// front ancestor's "AncestorKey" to the existing key. Note that a key
+  /// directly reflects the BFS and thus needs to be refreshed during the
+  /// progression of the traversal.
+  void refreshKey() {
+    if (ancestorQueue.empty())
+      return;
+
+    Operation *frontAncestor = ancestorQueue.front();
+    AncestorKey frontAncestorKey(frontAncestor);
+    key.push_back(frontAncestorKey);
+    return;
+  }
+
+  /// Pop the front ancestor, if any, from the queue and then push its adjacent
+  /// unvisited ancestors, if any, to the queue (this is the main body of the
+  /// BFS algorithm).
+  void popFrontAndPushAdjacentUnvisitedAncestors() {
+    if (ancestorQueue.empty())
+      return;
+    Operation *frontAncestor = ancestorQueue.front();
+    ancestorQueue.pop();
+    if (!frontAncestor)
+      return;
+    for (Value operand : frontAncestor->getOperands()) {
+      Operation *operandDefOp = operand.getDefiningOp();
+      if (!operandDefOp || !visitedAncestors.contains(operandDefOp))
+        pushAncestor(operandDefOp);
+    }
+    return;
+  }
+};
+
+/// Sorts the operands of `op` in ascending order of the "key" associated with
+/// each operand iff `op` is commutative. This is a stable sort.
+///
+/// After the application of this pattern, since the commutative operands now
+/// have a deterministic order in which they occur in an op, the matching of
+/// large DAGs becomes much simpler, i.e., users need to write far fewer
+/// checks in their pattern-matching functions.
+///
+/// Some examples of such a sorting:
+///
+/// Assume that the sorting is being applied to `foo.commutative`, which is a
+/// commutative op.
+///
+/// Example 1:
+///
+/// %1 = foo.const 0
+/// %2 = foo.mul <block argument>, <block argument>
+/// %3 = foo.commutative %1, %2
+///
+/// Here,
+/// 1. The key associated with %1 is:
+///     `{
+///       {CONSTANT_OP, "foo.const"}
+///      }`
+/// 2. The key associated with %2 is:
+///     `{
+///       {NON_CONSTANT_OP, "foo.mul"},
+///       {BLOCK_ARGUMENT, ""},
+///       {BLOCK_ARGUMENT, ""}
+///      }`
+///
+/// The key of %2 < the key of %1
+/// Thus, the sorted `foo.commutative` is:
+/// %3 = foo.commutative %2, %1
+///
+/// Example 2:
+///
+/// %1 = foo.const 0
+/// %2 = foo.mul <block argument>, <block argument>
+/// %3 = foo.mul %2, %1
+/// %4 = foo.add %2, %1
+/// %5 = foo.commutative %1, %2, %3, %4
+///
+/// Here,
+/// 1. The key associated with %1 is:
+///     `{
+///       {CONSTANT_OP, "foo.const"}
+///      }`
+/// 2. The key associated with %2 is:
+///     `{
+///       {NON_CONSTANT_OP, "foo.mul"},
+///       {BLOCK_ARGUMENT, ""}, {BLOCK_ARGUMENT, ""}
+///      }`
+/// 3. The key associated with %3 is:
+///     `{
+///       {NON_CONSTANT_OP, "foo.mul"},
+///       {NON_CONSTANT_OP, "foo.mul"},
+///       {CONSTANT_OP, "foo.const"},
+///       {BLOCK_ARGUMENT, ""},
+///       {BLOCK_ARGUMENT, ""}
+///      }`
+/// 4. The key associated with %4 is:
+///     `{
+///       {NON_CONSTANT_OP, "foo.add"},
+///       {NON_CONSTANT_OP, "foo.mul"},
+///       {CONSTANT_OP, "foo.const"},
+///       {BLOCK_ARGUMENT, ""},
+///       {BLOCK_ARGUMENT, ""}
+///      }`
+///
+/// Thus, the sorted `foo.commutative` is:
+/// %5 = foo.commutative %4, %2, %3, %1
+class SortCommutativeOperands : public RewritePattern {
+public:
+  SortCommutativeOperands(MLIRContext *context)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/5, context) {}
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // Custom comparator for two commutative operands, which returns true iff
+    // the "key" of `constCommOperandA` < the "key" of `constCommOperandB`,
+    // i.e.,
+    // 1. In the first unequal pair of corresponding AncestorKeys, the
+    // AncestorKey in `constCommOperandA` is smaller, or,
+    // 2. Every pair of corresponding AncestorKeys is equal and
+    // `constCommOperandA`'s "key" is shorter.
+    auto commutativeOperandComparator =
+        [](const std::unique_ptr<CommutativeOperand> &constCommOperandA,
+           const std::unique_ptr<CommutativeOperand> &constCommOperandB) {
+          if (constCommOperandA->operand == constCommOperandB->operand)
+            return false;
+
+          auto &commOperandA =
+              const_cast<std::unique_ptr<CommutativeOperand> &>(
+                  constCommOperandA);
+          auto &commOperandB =
+              const_cast<std::unique_ptr<CommutativeOperand> &>(
+                  constCommOperandB);
+
+          // Iteratively perform the BFS's of both operands until an order among
+          // them can be determined.
+          unsigned keyIndex = 0;
+          while (true) {
+            if (commOperandA->key.size() <= keyIndex) {
+              if (commOperandA->ancestorQueue.empty())
+                return true;
+              commOperandA->popFrontAndPushAdjacentUnvisitedAncestors();
+              commOperandA->refreshKey();
+            }
+            if (commOperandB->key.size() <= keyIndex) {
+              if (commOperandB->ancestorQueue.empty())
+                return false;
+              commOperandB->popFrontAndPushAdjacentUnvisitedAncestors();
+              commOperandB->refreshKey();
+            }
+            if (commOperandA->ancestorQueue.empty() ||
+                commOperandB->ancestorQueue.empty())
+              return commOperandA->key.size() < commOperandB->key.size();
+            if (commOperandA->key[keyIndex] < commOperandB->key[keyIndex])
+              return true;
+            if (commOperandB->key[keyIndex] < commOperandA->key[keyIndex])
+              return false;
+            keyIndex++;
+          }
+        };
+
+    // If `op` is not commutative, do nothing.
+    if (!op->hasTrait<OpTrait::IsCommutative>())
+      return failure();
+
+    // Populate the list of commutative operands.
+    SmallVector<Value, 2> operands = op->getOperands();
+    SmallVector<std::unique_ptr<CommutativeOperand>, 2> commOperands;
+    for (Value operand : operands) {
+      std::unique_ptr<CommutativeOperand> commOperand =
+          std::make_unique<CommutativeOperand>();
+      commOperand->operand = operand;
+      commOperand->pushAncestor(operand.getDefiningOp());
+      commOperand->refreshKey();
+      commOperands.push_back(std::move(commOperand));
+    }
+
+    // Sort the operands.
+    std::stable_sort(commOperands.begin(), commOperands.end(),
+                     commutativeOperandComparator);
+    SmallVector<Value, 2> sortedOperands;
+    for (const std::unique_ptr<CommutativeOperand> &commOperand : commOperands)
+      sortedOperands.push_back(commOperand->operand);
+    if (sortedOperands == operands)
+      return failure();
+    rewriter.updateRootInPlace(op, [&] { op->setOperands(sortedOperands); });
+    return success();
+  }
+};
+
+void mlir::populateCommutativityUtilsPatterns(RewritePatternSet &patterns) {
+  patterns.add<SortCommutativeOperands>(patterns.getContext());
+}

diff  --git a/mlir/test/Transforms/test-commutativity-utils.mlir b/mlir/test/Transforms/test-commutativity-utils.mlir
new file mode 100644
index 0000000000000..c544f165990a2
--- /dev/null
+++ b/mlir/test/Transforms/test-commutativity-utils.mlir
@@ -0,0 +1,116 @@
+// RUN: mlir-opt %s -test-commutativity-utils | FileCheck %s
+
+// CHECK-LABEL: @test_small_pattern_1
+func.func @test_small_pattern_1(%arg0 : i32) -> i32 {
+  // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant
+  %0 = arith.constant 45 : i32
+
+  // CHECK-NEXT: %[[TEST_ADD:.*]] = "test.addi"
+  %1 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
+
+  // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi
+  %2 = arith.addi %arg0, %arg0 : i32
+
+  // CHECK-NEXT: %[[ARITH_MUL:.*]] = arith.muli
+  %3 = arith.muli %arg0, %arg0 : i32
+
+  // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARITH_ADD]], %[[ARITH_MUL]], %[[TEST_ADD]], %[[ARITH_CONST]])
+  %result = "test.op_commutative"(%0, %1, %2, %3): (i32, i32, i32, i32) -> i32
+
+  // CHECK-NEXT: return %[[RESULT]]
+  return %result : i32
+}
+
+// CHECK-LABEL: @test_small_pattern_2
+// CHECK-SAME: (%[[ARG0:.*]]: i32
+func.func @test_small_pattern_2(%arg0 : i32) -> i32 {
+  // CHECK-NEXT: %[[TEST_CONST:.*]] = "test.constant"
+  %0 = "test.constant"() {value = 0 : i32} : () -> i32
+
+  // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant
+  %1 = arith.constant 0 : i32
+
+  // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi
+  %2 = arith.addi %arg0, %arg0 : i32
+
+  // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARG0]], %[[ARITH_ADD]], %[[ARITH_CONST]], %[[TEST_CONST]])
+  %result = "test.op_commutative"(%0, %1, %2, %arg0): (i32, i32, i32, i32) -> i32
+
+  // CHECK-NEXT: return %[[RESULT]]
+  return %result : i32
+}
+
+// CHECK-LABEL: @test_large_pattern
+func.func @test_large_pattern(%arg0 : i32, %arg1 : i32) -> i32 {
+  // CHECK-NEXT: arith.divsi
+  %0 = arith.divsi %arg0, %arg1 : i32
+
+  // CHECK-NEXT: arith.divsi
+  %1 = arith.divsi %0, %arg0 : i32
+
+  // CHECK-NEXT: arith.divsi
+  %2 = arith.divsi %1, %arg1 : i32
+
+  // CHECK-NEXT: arith.addi
+  %3 = arith.addi %1, %arg1 : i32
+
+  // CHECK-NEXT: arith.subi
+  %4 = arith.subi %2, %3 : i32
+
+  // CHECK-NEXT: "test.addi"
+  %5 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
+
+  // CHECK-NEXT: %[[VAL6:.*]] = arith.divsi
+  %6 = arith.divsi %4, %5 : i32
+
+  // CHECK-NEXT: arith.divsi
+  %7 = arith.divsi %1, %arg1 : i32
+
+  // CHECK-NEXT: %[[VAL8:.*]] = arith.muli
+  %8 = arith.muli %1, %arg1 : i32
+
+  // CHECK-NEXT: %[[VAL9:.*]] = arith.subi
+  %9 = arith.subi %7, %8 : i32
+
+  // CHECK-NEXT: "test.addi"
+  %10 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
+
+  // CHECK-NEXT: %[[VAL11:.*]] = arith.divsi
+  %11 = arith.divsi %9, %10 : i32
+
+  // CHECK-NEXT: %[[VAL12:.*]] = arith.divsi
+  %12 = arith.divsi %6, %arg1 : i32
+
+  // CHECK-NEXT: arith.subi
+  %13 = arith.subi %arg1, %arg0 : i32
+
+  // CHECK-NEXT: "test.op_commutative"(%[[VAL12]], %[[VAL12]], %[[VAL8]], %[[VAL9]])
+  %14 = "test.op_commutative"(%12, %9, %12, %8): (i32, i32, i32, i32) -> i32
+
+  // CHECK-NEXT: %[[VAL15:.*]] = arith.divsi
+  %15 = arith.divsi %13, %14 : i32
+
+  // CHECK-NEXT: %[[VAL16:.*]] = arith.addi
+  %16 = arith.addi %2, %15 : i32
+
+  // CHECK-NEXT: arith.subi
+  %17 = arith.subi %16, %arg1 : i32
+
+  // CHECK-NEXT: "test.addi"
+  %18 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
+
+  // CHECK-NEXT: %[[VAL19:.*]] = arith.divsi
+  %19 = arith.divsi %17, %18 : i32
+
+  // CHECK-NEXT: "test.addi"
+  %20 = "test.addi"(%arg0, %16): (i32, i32) -> i32
+
+  // CHECK-NEXT: %[[VAL21:.*]] = arith.divsi
+  %21 = arith.divsi %17, %20 : i32
+
+  // CHECK-NEXT: %[[RESULT:.*]] = "test.op_large_commutative"(%[[VAL16]], %[[VAL19]], %[[VAL19]], %[[VAL21]], %[[VAL6]], %[[VAL11]], %[[VAL15]])
+  %result = "test.op_large_commutative"(%16, %6, %11, %15, %19, %21, %19): (i32, i32, i32, i32, i32, i32, i32) -> i32
+
+  // CHECK-NEXT: return %[[RESULT]]
+  return %result : i32
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 038f49aeed456..57a5376285138 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1186,11 +1186,21 @@ def TestOpWithVariadicResultsAndFolder: TEST_Op<"op_with_variadic_results_and_fo
   let hasFolder = 1;
 }
 
+def TestAddIOp : TEST_Op<"addi"> {
+  let arguments = (ins I32:$op1, I32:$op2);
+  let results = (outs I32);
+}
+
 def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> {
   let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4);
   let results = (outs I32);
 }
 
+def TestLargeCommutativeOp : TEST_Op<"op_large_commutative", [Commutative]> {
+  let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4, I32:$op5, I32:$op6, I32:$op7);
+  let results = (outs I32);
+}
+
 def TestCommutative2Op : TEST_Op<"op_commutative2", [Commutative]> {
   let arguments = (ins I32:$op1, I32:$op2);
   let results = (outs I32);

diff  --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 00856562d8709..8ce5b982fefb9 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestTransforms
+  TestCommutativityUtils.cpp
   TestConstantFold.cpp
   TestControlFlowSink.cpp
   TestInlining.cpp

diff  --git a/mlir/test/lib/Transforms/TestCommutativityUtils.cpp b/mlir/test/lib/Transforms/TestCommutativityUtils.cpp
new file mode 100644
index 0000000000000..2ec0334ae0d05
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestCommutativityUtils.cpp
@@ -0,0 +1,48 @@
+//===- TestCommutativityUtils.cpp - Pass to test the commutativity utility-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass tests the functionality of the commutativity utility pattern.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/CommutativityUtils.h"
+
+#include "TestDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+struct CommutativityUtils
+    : public PassWrapper<CommutativityUtils, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CommutativityUtils)
+
+  StringRef getArgument() const final { return "test-commutativity-utils"; }
+  StringRef getDescription() const final {
+    return "Test the functionality of the commutativity utility";
+  }
+
+  void runOnOperation() override {
+    auto func = getOperation();
+    auto *context = &getContext();
+
+    RewritePatternSet patterns(context);
+    populateCommutativityUtilsPatterns(patterns);
+
+    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerCommutativityUtils() { PassRegistration<CommutativityUtils>(); }
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 78e26de1d54f0..3d48ec2987ac6 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -57,6 +57,7 @@ void registerTosaTestQuantUtilAPIPass();
 void registerVectorizerTestPass();
 
 namespace test {
+void registerCommutativityUtils();
 void registerConvertCallOpPass();
 void registerInliner();
 void registerMemRefBoundCheck();
@@ -149,6 +150,7 @@ void registerTestPasses() {
   registerVectorizerTestPass();
   registerTosaTestQuantUtilAPIPass();
 
+  mlir::test::registerCommutativityUtils();
   mlir::test::registerConvertCallOpPass();
   mlir::test::registerInliner();
   mlir::test::registerMemRefBoundCheck();


        

