[Mlir-commits] [mlir] [MLIR] Add canonicalizations to all eligible `index` binary ops (PR #114000)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 28 21:23:20 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Nachi G (nacgarg)

<details>
<summary>Changes</summary>

Generalizes the following canonicalization pattern to all associative and commutative binary ops in the `index` dialect.

```
x = v + c1
y = x + c2
   -->
y = v + (c1 + c2)
```

This includes:
- `AddOp`
- `MulOp`
- `MaxSOp`
- `MaxUOp`
- `MinSOp`
- `MinUOp`
- `AndOp`
- `OrOp`
- `XOrOp`

The operation folding is implemented using the existing folders since `createOrFold` is used in the canonicalization.

---
Full diff: https://github.com/llvm/llvm-project/pull/114000.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Index/IR/IndexOps.td (+16) 
- (modified) mlir/lib/Dialect/Index/IR/IndexOps.cpp (+59-20) 
- (modified) mlir/test/Dialect/Index/index-canonicalize.mlir (+105-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index ce1355316b09b8..230a3815bdd81e 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -95,6 +95,8 @@ def Index_MulOp : IndexBinaryOp<"mul", [Commutative, Pure]> {
     %c = index.mul %a, %b
     ```
   }];
+
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -263,6 +265,8 @@ def Index_MaxSOp : IndexBinaryOp<"maxs", [Commutative, Pure]> {
     %c = index.maxs %a, %b
     ```
   }];
+
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -283,6 +287,8 @@ def Index_MaxUOp : IndexBinaryOp<"maxu", [Commutative, Pure]> {
     %c = index.maxu %a, %b
     ```
   }];
+
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -302,6 +308,8 @@ def Index_MinSOp : IndexBinaryOp<"mins", [Commutative, Pure]> {
     %c = index.mins %a, %b
     ```
   }];
+
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -322,6 +330,8 @@ def Index_MinUOp : IndexBinaryOp<"minu", [Commutative, Pure]> {
     %c = index.minu %a, %b
     ```
   }];
+
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -404,6 +414,8 @@ def Index_AndOp : IndexBinaryOp<"and", [Commutative, Pure]> {
     %c = index.and %a, %b
     ```
   }];
+
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -423,6 +435,8 @@ def Index_OrOp : IndexBinaryOp<"or", [Commutative, Pure]> {
     %c = index.or %a, %b
     ```
   }];
+
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -442,6 +456,8 @@ def Index_XOrOp : IndexBinaryOp<"xor", [Commutative, Pure]> {
     %c = index.xor %a, %b
     ```
   }];
+
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index 0b58eb80f93032..a650b05767bbb7 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -118,6 +118,31 @@ static OpFoldResult foldBinaryOpChecked(
   return IntegerAttr::get(IndexType::get(lhs.getContext()), *result64);
 }
 
+/// Helper for associative and commutative binary ops that can be transformed:
+/// `x = op(v, c1); y = op(x, c2)` -> `tmp = op(c1, c2); y = op(v, tmp)`
+/// where c1 and c2 are constants. It is expected that `tmp` will be folded.
+template <typename BinaryOp>
+static LogicalResult
+canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op,
+                                           PatternRewriter &rewriter) {
+  IntegerAttr c1, c2;
+  if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c1)))
+    return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
+
+  auto lhsOp = op.getLhs().template getDefiningOp<BinaryOp>();
+  if (!lhsOp)
+    return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not a add");
+
+  if (!mlir::matchPattern(lhsOp.getRhs(), mlir::m_Constant(&c2)))
+    return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
+
+  auto c = rewriter.createOrFold<BinaryOp>(op->getLoc(), op.getRhs(),
+                                           lhsOp.getRhs());
+
+  rewriter.replaceOpWithNewOp<BinaryOp>(op, lhsOp.getLhs(), c);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // AddOp
 //===----------------------------------------------------------------------===//
@@ -136,27 +161,9 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
 
   return {};
 }
-/// Canonicalize
-/// ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)`
-LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
-  IntegerAttr c1, c2;
-  if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c1)))
-    return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
-
-  auto add = op.getLhs().getDefiningOp<mlir::index::AddOp>();
-  if (!add)
-    return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not a add");
-
-  if (!mlir::matchPattern(add.getRhs(), mlir::m_Constant(&c2)))
-    return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
-
-  auto c = rewriter.create<mlir::index::ConstantOp>(op->getLoc(),
-                                                    c1.getInt() + c2.getInt());
-  auto newAdd =
-      rewriter.create<mlir::index::AddOp>(op->getLoc(), add.getLhs(), c);
 
-  rewriter.replaceOp(op, newAdd);
-  return success();
+LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
+  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
 }
 
 //===----------------------------------------------------------------------===//
@@ -200,6 +207,10 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
+  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
+}
+
 //===----------------------------------------------------------------------===//
 // DivSOp
 //===----------------------------------------------------------------------===//
@@ -352,6 +363,10 @@ OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) {
                              });
 }
 
+LogicalResult MaxSOp::canonicalize(MaxSOp op, PatternRewriter &rewriter) {
+  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
+}
+
 //===----------------------------------------------------------------------===//
 // MaxUOp
 //===----------------------------------------------------------------------===//
@@ -363,6 +378,10 @@ OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
                              });
 }
 
+LogicalResult MaxUOp::canonicalize(MaxUOp op, PatternRewriter &rewriter) {
+  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
+}
+
 //===----------------------------------------------------------------------===//
 // MinSOp
 //===----------------------------------------------------------------------===//
@@ -374,6 +393,10 @@ OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
                              });
 }
 
+LogicalResult MinSOp::canonicalize(MinSOp op, PatternRewriter &rewriter) {
+  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
+}
+
 //===----------------------------------------------------------------------===//
 // MinUOp
 //===----------------------------------------------------------------------===//
@@ -385,6 +408,10 @@ OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
                              });
 }
 
+LogicalResult MinUOp::canonicalize(MinUOp op, PatternRewriter &rewriter) {
+  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
+}
+
 //===----------------------------------------------------------------------===//
 // ShlOp
 //===----------------------------------------------------------------------===//
@@ -442,6 +469,10 @@ OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
       [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
 }
 
+LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
+  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
+}
+
 //===----------------------------------------------------------------------===//
 // OrOp
 //===----------------------------------------------------------------------===//
@@ -452,6 +483,10 @@ OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
       [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
 }
 
+LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
+  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
+}
+
 //===----------------------------------------------------------------------===//
 // XOrOp
 //===----------------------------------------------------------------------===//
@@ -462,6 +497,10 @@ OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
       [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
 }
 
+LogicalResult XOrOp::canonicalize(XOrOp op, PatternRewriter &rewriter) {
+  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
+}
+
 //===----------------------------------------------------------------------===//
 // CastSOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index a29b09c11f7f62..cecc29a13fc901 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -32,7 +32,7 @@ func.func @add_overflow() -> (index, index) {
   return %2, %3 : index, index
 }
 
-// CHECK-LABEL: @add
+// CHECK-LABEL: @add_fold_constants
 func.func @add_fold_constants(%arg: index) -> (index) {
   %0 = index.constant 1
   %1 = index.constant 2
@@ -65,6 +65,19 @@ func.func @mul() -> index {
   return %2 : index
 }
 
+// CHECK-LABEL: @mul_fold_constants
+func.func @mul_fold_constants(%arg: index) -> (index) {
+  %0 = index.constant 2
+  %1 = index.constant 3
+  %2 = index.mul %arg, %0
+  %3 = index.mul %2, %1
+
+  // CHECK-DAG: [[C6:%.*]] = index.constant 6
+  // CHECK-DAG: [[V0:%.*]] = index.mul %arg0, [[C6]]
+  // CHECK: return [[V0]]
+  return %3 : index
+}
+
 // CHECK-LABEL: @divs
 func.func @divs() -> index {
   %0 = index.constant -2
@@ -300,6 +313,19 @@ func.func @maxs_edge() -> index {
   return %0 : index
 }
 
+// CHECK-LABEL: @maxs_fold_constants
+func.func @maxs_fold_constants(%arg: index) -> (index) {
+  %0 = index.constant 2
+  %1 = index.constant 3
+  %2 = index.maxs %arg, %0
+  %3 = index.maxs %2, %1
+
+  // CHECK-DAG: [[C3:%.*]] = index.constant 3
+  // CHECK-DAG: [[V0:%.*]] = index.maxs %arg0, [[C3]]
+  // CHECK: return [[V0]]
+  return %3 : index
+}
+
 // CHECK-LABEL: @maxu
 func.func @maxu() -> index {
   %lhs = index.constant -1
@@ -310,6 +336,19 @@ func.func @maxu() -> index {
   return %0 : index
 }
 
+// CHECK-LABEL: @maxu_fold_constants
+func.func @maxu_fold_constants(%arg: index) -> (index) {
+  %0 = index.constant 2
+  %1 = index.constant 3
+  %2 = index.maxu %arg, %0
+  %3 = index.maxu %2, %1
+
+  // CHECK-DAG: [[C3:%.*]] = index.constant 3
+  // CHECK-DAG: [[V0:%.*]] = index.maxu %arg0, [[C3]]
+  // CHECK: return [[V0]]
+  return %3 : index
+}
+
 // CHECK-LABEL: @mins
 func.func @mins() -> index {
   %lhs = index.constant -4
@@ -340,6 +379,19 @@ func.func @mins_nofold_2() -> index {
   return %0 : index
 }
 
+// CHECK-LABEL: @mins_fold_constants
+func.func @mins_fold_constants(%arg: index) -> (index) {
+  %0 = index.constant 2
+  %1 = index.constant 3
+  %2 = index.mins %arg, %0
+  %3 = index.mins %2, %1
+
+  // CHECK-DAG: [[C2:%.*]] = index.constant 2
+  // CHECK-DAG: [[V0:%.*]] = index.mins %arg0, [[C2]]
+  // CHECK: return [[V0]]
+  return %3 : index
+}
+
 // CHECK-LABEL: @minu
 func.func @minu() -> index {
   %lhs = index.constant -1
@@ -350,6 +402,19 @@ func.func @minu() -> index {
   return %0 : index
 }
 
+// CHECK-LABEL: @minu_fold_constants
+func.func @minu_fold_constants(%arg: index) -> (index) {
+  %0 = index.constant 2
+  %1 = index.constant 3
+  %2 = index.minu %arg, %0
+  %3 = index.minu %2, %1
+
+  // CHECK-DAG: [[C2:%.*]] = index.constant 2
+  // CHECK-DAG: [[V0:%.*]] = index.minu %arg0, [[C2]]
+  // CHECK: return [[V0]]
+  return %3 : index
+}
+
 // CHECK-LABEL: @shl
 func.func @shl() -> index {
   %lhs = index.constant 128
@@ -465,6 +530,19 @@ func.func @and() -> index {
   return %0 : index
 }
 
+// CHECK-LABEL: @and_fold_constants
+func.func @and_fold_constants(%arg: index) -> (index) {
+  %0 = index.constant 5
+  %1 = index.constant 1
+  %2 = index.and %arg, %0
+  %3 = index.and %2, %1
+
+  // CHECK-DAG: [[C1:%.*]] = index.constant 1
+  // CHECK-DAG: [[V0:%.*]] = index.and %arg0, [[C1]]
+  // CHECK: return [[V0]]
+  return %3 : index
+}
+
 // CHECK-LABEL: @or
 func.func @or() -> index {
   %lhs = index.constant 5
@@ -475,6 +553,19 @@ func.func @or() -> index {
   return %0 : index
 }
 
+// CHECK-LABEL: @or_fold_constants
+func.func @or_fold_constants(%arg: index) -> (index) {
+  %0 = index.constant 5
+  %1 = index.constant 1
+  %2 = index.or %arg, %0
+  %3 = index.or %2, %1
+
+  // CHECK-DAG: [[C5:%.*]] = index.constant 5
+  // CHECK-DAG: [[V0:%.*]] = index.or %arg0, [[C5]]
+  // CHECK: return [[V0]]
+  return %3 : index
+}
+
 // CHECK-LABEL: @xor
 func.func @xor() -> index {
   %lhs = index.constant 5
@@ -485,6 +576,19 @@ func.func @xor() -> index {
   return %0 : index
 }
 
+// CHECK-LABEL: @xor_fold_constants
+func.func @xor_fold_constants(%arg: index) -> (index) {
+  %0 = index.constant 5
+  %1 = index.constant 1
+  %2 = index.xor %arg, %0
+  %3 = index.xor %2, %1
+
+  // CHECK-DAG: [[C4:%.*]] = index.constant 4
+  // CHECK-DAG: [[V0:%.*]] = index.xor %arg0, [[C4]]
+  // CHECK: return [[V0]]
+  return %3 : index
+}
+
 // CHECK-LABEL: @cmp
 func.func @cmp(%arg0: index) -> (i1, i1, i1, i1, i1, i1) {
   %a = index.constant 0

``````````

</details>


https://github.com/llvm/llvm-project/pull/114000


More information about the Mlir-commits mailing list