[Mlir-commits] [mlir] [MLIR] Add folding constants canonicalization for mlir::index::AddOp. (PR #111084)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 3 19:48:19 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: weiwei chen (weiweichen)

<details>
<summary>Changes</summary>

- [x] Add a simple canonicalization for `mlir::index::AddOp` that merges two chained additions of constants, e.g. rewriting `(x + c1) + c2` to `x + (c1 + c2)`.

---
Full diff: https://github.com/llvm/llvm-project/pull/111084.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Index/IR/IndexOps.td (+2) 
- (modified) mlir/lib/Dialect/Index/IR/IndexOps.cpp (+39) 
- (modified) mlir/test/Dialect/Index/index-canonicalize.mlir (+15) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index a30ae9f739cbc6..ce1355316b09b8 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -56,6 +56,8 @@ def Index_AddOp : IndexBinaryOp<"add", [Commutative, Pure]> {
     %c = index.add %a, %b
     ```
   }];
+
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index 42401dae217ce1..ace9b43014a665 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -136,6 +136,45 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
 
   return {};
 }
+/// Canonicalize two chained additions of constants into a single addition:
+///   `x = v + c1; y = x + c2` to `y = v + (c1 + c2)`
+///   `x = v + c1; y = c2 + x` to `y = v + (c1 + c2)`
+///   `x = c1 + v; y = x + c2` to `y = v + (c1 + c2)`
+///   `x = c1 + v; y = c2 + x` to `y = v + (c1 + c2)`
+LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
+  // Split `addOp` into its non-constant operand `v` and constant operand `c`.
+  auto matchConstant = [](AddOp addOp, Value &v, IntegerAttr &c) {
+    v = addOp.getLhs();
+    if (mlir::matchPattern(addOp.getRhs(), mlir::m_Constant(&c)))
+      return true;
+    v = addOp.getRhs();
+    return mlir::matchPattern(addOp.getLhs(), mlir::m_Constant(&c));
+  };
+
+  IntegerAttr c1, c2;
+  Value v1, v2;
+
+  if (!matchConstant(op, v1, c1))
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "neither LHS nor RHS is constant");
+
+  auto add = v1.getDefiningOp<AddOp>();
+  if (!add)
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "non-constant operand is not an add");
+
+  if (!matchConstant(add, v2, c2))
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "neither LHS nor RHS is constant");
+
+  // Sum as APInt so overflow wraps rather than being signed-overflow UB; a
+  // 64-bit wrapped sum is still exact modulo 2^32 for 32-bit index targets.
+  APInt sum = c1.getValue() + c2.getValue();
+  auto cst = rewriter.create<mlir::index::ConstantOp>(op->getLoc(),
+                                                      sum.getSExtValue());
+  rewriter.replaceOpWithNewOp<AddOp>(op, v2, cst);
+  return success();
+}
 
 //===----------------------------------------------------------------------===//
 // SubOp
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index 37aa33bfde952e..256e327e83ea9c 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -32,6 +32,21 @@ func.func @add_overflow() -> (index, index) {
   return %2, %3 : index, index
 }
 
+// CHECK-LABEL: @add_fold_constants
+func.func @add_fold_constants(%arg: index) -> (index) {
+  %0 = index.constant 1
+  %1 = index.constant 2
+  %2 = index.add %arg, %0
+  %3 = index.add %1, %2
+  %4 = index.add %3, %1
+  %5 = index.add %4, %0
+  // ((((arg + 1) + 2) + 2) + 1) should collapse to a single arg + 6.
+  // CHECK-DAG: [[A:%.*]] = index.constant 6
+  // CHECK-DAG: [[V0:%.*]] = index.add %arg0, [[A]]
+  // CHECK: return [[V0]]
+  return %5 : index
+}
+
 // CHECK-LABEL: @sub
 func.func @sub() -> index {
   %0 = index.constant -2000000000

``````````

</details>


https://github.com/llvm/llvm-project/pull/111084


More information about the Mlir-commits mailing list