[Mlir-commits] [mlir] 7191ced - [MLIR] Add folding constants canonicalization for mlir::index::AddOp. (#111084)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 22 12:04:29 PDT 2024


Author: weiwei chen
Date: 2024-10-22T12:04:26-07:00
New Revision: 7191ced3b69e6f4f0e67056be416e399d0a8d7ca

URL: https://github.com/llvm/llvm-project/commit/7191ced3b69e6f4f0e67056be416e399d0a8d7ca
DIFF: https://github.com/llvm/llvm-project/commit/7191ced3b69e6f4f0e67056be416e399d0a8d7ca.diff

LOG: [MLIR] Add folding constants canonicalization for mlir::index::AddOp. (#111084)

- [x] Add a simple canonicalization for `mlir::index::AddOp`.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Index/IR/IndexOps.td
    mlir/lib/Dialect/Index/IR/IndexOps.cpp
    mlir/test/Dialect/Index/index-canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index a30ae9f739cbc6..ce1355316b09b8 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -56,6 +56,8 @@ def Index_AddOp : IndexBinaryOp<"add", [Commutative, Pure]> {
     %c = index.add %a, %b
     ```
   }];
+
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index 5ad989b7da126e..0b58eb80f93032 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -136,6 +136,28 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
 
   return {};
 }
+/// Canonicalize `x = v + c2; y = x + c1` to `y = v + (c1 + c2)`.
+/// Only `y` is replaced; the inner add `x` is not erased by this pattern.
+LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
+  IntegerAttr c1, c2;
+  if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c1)))
+    return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
+
+  auto add = op.getLhs().getDefiningOp<mlir::index::AddOp>();
+  if (!add)
+    return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not an add");
+
+  if (!mlir::matchPattern(add.getRhs(), mlir::m_Constant(&c2)))
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "RHS of the LHS add is not a constant");
+
+  // Add as APInt so overflow wraps modulo 2^bitwidth, whereas
+  // `c1.getInt() + c2.getInt()` is signed int64_t overflow (UB).
+  auto c = rewriter.create<mlir::index::ConstantOp>(
+      op->getLoc(), (c1.getValue() + c2.getValue()).getSExtValue());
+  rewriter.replaceOpWithNewOp<mlir::index::AddOp>(op, add.getLhs(), c);
+  return success();
+}
 
 //===----------------------------------------------------------------------===//
 // SubOp

diff  --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index 37aa33bfde952e..a29b09c11f7f62 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -32,6 +32,19 @@ func.func @add_overflow() -> (index, index) {
   return %2, %3 : index, index
 }
 
+// CHECK-LABEL: @add_fold_constants
+func.func @add_fold_constants(%arg: index) -> (index) {
+  %0 = index.constant 1
+  %1 = index.constant 2
+  %2 = index.add %arg, %0
+  %3 = index.add %2, %1
+  // (%arg + 1) + 2 is canonicalized to %arg + 3.
+  // CHECK-DAG: [[C3:%.*]] = index.constant 3
+  // CHECK-DAG: [[V0:%.*]] = index.add %arg0, [[C3]]
+  // CHECK: return [[V0]]
+  return %3 : index
+}
+
 // CHECK-LABEL: @sub
 func.func @sub() -> index {
   %0 = index.constant -2000000000


        


More information about the Mlir-commits mailing list