[Mlir-commits] [mlir] [MLIR] Add folding constants canonicalization for mlir::index::AddOp. (PR #111084)

weiwei chen llvmlistbot at llvm.org
Tue Oct 8 07:41:37 PDT 2024


https://github.com/weiweichen updated https://github.com/llvm/llvm-project/pull/111084

>From 8f90d3855f444e7f701baae9e763b824e0060c60 Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Thu, 3 Oct 2024 22:45:59 -0400
Subject: [PATCH 1/2] Add constant-folding canonicalization for
 mlir::index::AddOp.

---
 .../include/mlir/Dialect/Index/IR/IndexOps.td |  2 +
 mlir/lib/Dialect/Index/IR/IndexOps.cpp        | 39 +++++++++++++++++++
 .../Dialect/Index/index-canonicalize.mlir     | 15 +++++++
 3 files changed, 56 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index a30ae9f739cbc6..ce1355316b09b8 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -56,6 +56,8 @@ def Index_AddOp : IndexBinaryOp<"add", [Commutative, Pure]> {
     %c = index.add %a, %b
     ```
   }];
+
+  let hasCanonicalizeMethod = 1;
 }

 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index 42401dae217ce1..ace9b43014a665 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -136,6 +136,45 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {

   return {};
 }
+/// Canonicalize
+/// ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)`
+/// ` x = v + c1; y = c2 + x` to `x = v + (c1 + c2)`
+/// ` x = c1 + v; y = x + c2` to `x = v + (c1 + c2)`
+/// ` x = c1 + v; y = c2 + x` to `x = v + (c1 + c2)`
+LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
+
+  auto matchConstant = [](mlir::index::AddOp op, Value &v, IntegerAttr &c) {
+    v = op.getLhs();
+    if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c))) {
+      v = op.getRhs();
+      if (!mlir::matchPattern(op.getLhs(), mlir::m_Constant(&c)))
+        return false;
+    }
+    return true;
+  };
+
+  IntegerAttr c1, c2;
+  Value v1, v2;
+
+  if (!matchConstant(op, v1, c1))
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "neither LHS nor RHS is constant");
+
+  auto add = v1.getDefiningOp<mlir::index::AddOp>();
+  if (!add)
+    return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not a add");
+
+  if (!matchConstant(add, v2, c2))
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "neither LHS nor RHS is constant");
+
+  auto c = rewriter.create<mlir::index::ConstantOp>(op->getLoc(),
+                                                    c1.getInt() + c2.getInt());
+  auto newAdd = rewriter.create<mlir::index::AddOp>(op->getLoc(), v2, c);
+
+  rewriter.replaceOp(op, newAdd);
+  return success();
+}

 //===----------------------------------------------------------------------===//
 // SubOp
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index 37aa33bfde952e..256e327e83ea9c 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -32,6 +32,21 @@ func.func @add_overflow() -> (index, index) {
   return %2, %3 : index, index
 }

+// CHECK-LABEL: @add_fold_constants
+func.func @add_fold_constants(%arg: index) -> (index) {
+  %0 = index.constant 1
+  %1 = index.constant 2
+  %2 = index.add %arg, %0
+  %3 = index.add %1, %2
+  %4 = index.add %3, %1
+  %5 = index.add %4, %0
+
+  // CHECK-DAG: [[A:%.*]] = index.constant 6
+  // CHECK-DAG: [[V0:%.*]] = index.add %arg0, [[A]]
+  // CHECK: return [[V0]]
+  return %5 : index
+}
+
 // CHECK-LABEL: @sub
 func.func @sub() -> index {
   %0 = index.constant -2000000000

>From c755012c9d6bf7af7fd8521b50ec9b1eab471604 Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Tue, 8 Oct 2024 10:39:56 -0400
Subject: [PATCH 2/2] Address review comments: match constants on the RHS only.

---
 mlir/lib/Dialect/Index/IR/IndexOps.cpp        | 45 +++++--------
 .../Dialect/Index/index-canonicalize.mlir     | 24 +++++--
 2 files changed, 35 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index ace9b43014a665..dbc63d9d107587 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -138,39 +138,28 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
 }
-/// Canonicalize
-/// ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)`
-/// ` x = v + c1; y = c2 + x` to `x = v + (c1 + c2)`
-/// ` x = c1 + v; y = x + c2` to `x = v + (c1 + c2)`
-/// ` x = c1 + v; y = c2 + x` to `x = v + (c1 + c2)`
+
+/// Canonicalize `x = v + c1; y = x + c2` to `y = v + (c1 + c2)`.
+///
+/// Only constants on the RHS of both adds are matched. The constant sum is
+/// computed with wrapping APInt arithmetic rather than as `int64_t`, whose
+/// overflow would be undefined behavior.
 LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
-
-  auto matchConstant = [](mlir::index::AddOp op, Value &v, IntegerAttr &c) {
-    v = op.getLhs();
-    if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c))) {
-      v = op.getRhs();
-      if (!mlir::matchPattern(op.getLhs(), mlir::m_Constant(&c)))
-        return false;
-    }
-    return true;
-  };
-
   IntegerAttr c1, c2;
-  Value v1, v2;
+  if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c1)))
+    return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");

-  if (!matchConstant(op, v1, c1))
-    return rewriter.notifyMatchFailure(op.getLoc(),
-                                       "neither LHS nor RHS is constant");
-
-  auto add = v1.getDefiningOp<mlir::index::AddOp>();
+  auto add = op.getLhs().getDefiningOp<mlir::index::AddOp>();
   if (!add)
     return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not a add");

-  if (!matchConstant(add, v2, c2))
-    return rewriter.notifyMatchFailure(op.getLoc(),
-                                       "neither LHS nor RHS is constant");
+  if (!mlir::matchPattern(add.getRhs(), mlir::m_Constant(&c2)))
+    return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");

-  auto c = rewriter.create<mlir::index::ConstantOp>(op->getLoc(),
-                                                    c1.getInt() + c2.getInt());
-  auto newAdd = rewriter.create<mlir::index::AddOp>(op->getLoc(), v2, c);
+  // Sum in APInt so that large constants wrap instead of overflowing int64_t.
+  APInt sum = c1.getValue() + c2.getValue();
+  auto c = rewriter.create<mlir::index::ConstantOp>(op->getLoc(),
+                                                    sum.getSExtValue());
+  auto newAdd =
+      rewriter.create<mlir::index::AddOp>(op->getLoc(), add.getLhs(), c);

   rewriter.replaceOp(op, newAdd);
   return success();
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index 256e327e83ea9c..a29b09c11f7f62 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -37,14 +37,26 @@ func.func @add_fold_constants(%arg: index) -> (index) {
   %0 = index.constant 1
   %1 = index.constant 2
   %2 = index.add %arg, %0
-  %3 = index.add %1, %2
-  %4 = index.add %3, %1
-  %5 = index.add %4, %0
+  %3 = index.add %2, %1

-  // CHECK-DAG: [[A:%.*]] = index.constant 6
-  // CHECK-DAG: [[V0:%.*]] = index.add %arg0, [[A]]
+  // CHECK-DAG: [[C3:%.*]] = index.constant 3
+  // CHECK-DAG: [[V0:%.*]] = index.add %arg0, [[C3]]
   // CHECK: return [[V0]]
-  return %5 : index
+  return %3 : index
 }

+// CHECK-LABEL: @add_fold_constants_wrap
+func.func @add_fold_constants_wrap(%arg: index) -> (index) {
+  %0 = index.constant 9223372036854775807
+  %1 = index.constant 1
+  %2 = index.add %arg, %0
+  %3 = index.add %2, %1
+
+  // INT64_MAX + 1 wraps around to INT64_MIN.
+  // CHECK-DAG: [[CMIN:%.*]] = index.constant -9223372036854775808
+  // CHECK-DAG: [[V0:%.*]] = index.add %arg0, [[CMIN]]
+  // CHECK: return [[V0]]
+  return %3 : index
+}
+
 // CHECK-LABEL: @sub



More information about the Mlir-commits mailing list