[Mlir-commits] [mlir] [MLIR] Add folding constants canonicalization for mlir::index::AddOp. (PR #111084)
weiwei chen
llvmlistbot at llvm.org
Tue Oct 8 07:41:37 PDT 2024
https://github.com/weiweichen updated https://github.com/llvm/llvm-project/pull/111084
>From 8f90d3855f444e7f701baae9e763b824e0060c60 Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Thu, 3 Oct 2024 22:45:59 -0400
Subject: [PATCH 1/2] Add folding constants canonicalization for
mlir::index::AddOp.
---
.../include/mlir/Dialect/Index/IR/IndexOps.td | 2 +
mlir/lib/Dialect/Index/IR/IndexOps.cpp | 39 +++++++++++++++++++
.../Dialect/Index/index-canonicalize.mlir | 15 +++++++
3 files changed, 56 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index a30ae9f739cbc6..ce1355316b09b8 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -56,6 +56,8 @@ def Index_AddOp : IndexBinaryOp<"add", [Commutative, Pure]> {
%c = index.add %a, %b
```
}];
+
+ let hasCanonicalizeMethod = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index 42401dae217ce1..ace9b43014a665 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -136,6 +136,45 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
return {};
}
+/// Canonicalize
+/// ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)`
+/// ` x = v + c1; y = c2 + x` to `x = v + (c1 + c2)`
+/// ` x = c1 + v; y = x + c2` to `x = v + (c1 + c2)`
+/// ` x = c1 + v; y = c2 + x` to `x = v + (c1 + c2)`
+LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
+
+ auto matchConstant = [](mlir::index::AddOp op, Value &v, IntegerAttr &c) {
+ v = op.getLhs();
+ if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c))) {
+ v = op.getRhs();
+ if (!mlir::matchPattern(op.getLhs(), mlir::m_Constant(&c)))
+ return false;
+ }
+ return true;
+ };
+
+ IntegerAttr c1, c2;
+ Value v1, v2;
+
+ if (!matchConstant(op, v1, c1))
+ return rewriter.notifyMatchFailure(op.getLoc(),
+ "neither LHS nor RHS is constant");
+
+ auto add = v1.getDefiningOp<mlir::index::AddOp>();
+ if (!add)
+ return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not a add");
+
+ if (!matchConstant(add, v2, c2))
+ return rewriter.notifyMatchFailure(op.getLoc(),
+ "neither LHS nor RHS is constant");
+
+ auto c = rewriter.create<mlir::index::ConstantOp>(op->getLoc(),
+ c1.getInt() + c2.getInt());
+ auto newAdd = rewriter.create<mlir::index::AddOp>(op->getLoc(), v2, c);
+
+ rewriter.replaceOp(op, newAdd);
+ return success();
+}
//===----------------------------------------------------------------------===//
// SubOp
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index 37aa33bfde952e..256e327e83ea9c 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -32,6 +32,21 @@ func.func @add_overflow() -> (index, index) {
return %2, %3 : index, index
}
+// CHECK-LABEL: @add_fold_constants
+func.func @add_fold_constants(%arg: index) -> (index) {
+ %0 = index.constant 1
+ %1 = index.constant 2
+ %2 = index.add %arg, %0
+ %3 = index.add %1, %2
+ %4 = index.add %3, %1
+ %5 = index.add %4, %0
+
+ // CHECK-DAG: [[A:%.*]] = index.constant 6
+ // CHECK-DAG: [[V0:%.*]] = index.add %arg0, [[A]]
+ // CHECK: return [[V0]]
+ return %5 : index
+}
+
// CHECK-LABEL: @sub
func.func @sub() -> index {
%0 = index.constant -2000000000
>From c755012c9d6bf7af7fd8521b50ec9b1eab471604 Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Tue, 8 Oct 2024 10:39:56 -0400
Subject: [PATCH 2/2] Address review comments.
---
mlir/lib/Dialect/Index/IR/IndexOps.cpp | 31 +++++--------------
.../Dialect/Index/index-canonicalize.mlir | 10 +++---
2 files changed, 11 insertions(+), 30 deletions(-)
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index ace9b43014a665..dbc63d9d107587 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -138,39 +138,27 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
 }
+
 /// Canonicalize
-/// ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)`
-/// ` x = v + c1; y = c2 + x` to `x = v + (c1 + c2)`
-/// ` x = c1 + v; y = x + c2` to `x = v + (c1 + c2)`
-/// ` x = c1 + v; y = c2 + x` to `x = v + (c1 + c2)`
+/// `x = v + c1; y = x + c2` to `y = v + (c1 + c2)`
+///
+/// Only constants on the RHS are matched: `index.add` is `Commutative`, and
+/// that trait's folding moves constant operands to the RHS.
 LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
-
-  auto matchConstant = [](mlir::index::AddOp op, Value &v, IntegerAttr &c) {
-    v = op.getLhs();
-    if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c))) {
-      v = op.getRhs();
-      if (!mlir::matchPattern(op.getLhs(), mlir::m_Constant(&c)))
-        return false;
-    }
-    return true;
-  };
-
   IntegerAttr c1, c2;
-  Value v1, v2;
+  if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c1)))
+    return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
 
-  if (!matchConstant(op, v1, c1))
-    return rewriter.notifyMatchFailure(op.getLoc(),
-                                       "neither LHS nor RHS is constant");
-
-  auto add = v1.getDefiningOp<mlir::index::AddOp>();
+  auto add = op.getLhs().getDefiningOp<mlir::index::AddOp>();
   if (!add)
-    return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not a add");
+    return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not an add");
 
-  if (!matchConstant(add, v2, c2))
-    return rewriter.notifyMatchFailure(op.getLoc(),
-                                       "neither LHS nor RHS is constant");
+  if (!mlir::matchPattern(add.getRhs(), mlir::m_Constant(&c2)))
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "RHS of the LHS add is not a constant");
 
-  auto c = rewriter.create<mlir::index::ConstantOp>(op->getLoc(),
-                                                    c1.getInt() + c2.getInt());
-  auto newAdd = rewriter.create<mlir::index::AddOp>(op->getLoc(), v2, c);
-
-  rewriter.replaceOp(op, newAdd);
+  // Sum as APInt so overflow wraps (two's complement), matching index
+  // arithmetic, instead of being signed-overflow UB on int64_t.
+  APInt sum = c1.getValue() + c2.getValue();
+  auto c = rewriter.create<mlir::index::ConstantOp>(op->getLoc(),
+                                                    sum.getSExtValue());
+  rewriter.replaceOpWithNewOp<mlir::index::AddOp>(op, add.getLhs(), c);
   return success();
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index 256e327e83ea9c..a29b09c11f7f62 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -37,14 +37,12 @@ func.func @add_fold_constants(%arg: index) -> (index) {
%0 = index.constant 1
%1 = index.constant 2
%2 = index.add %arg, %0
- %3 = index.add %1, %2
- %4 = index.add %3, %1
- %5 = index.add %4, %0
+ %3 = index.add %2, %1
- // CHECK-DAG: [[A:%.*]] = index.constant 6
- // CHECK-DAG: [[V0:%.*]] = index.add %arg0, [[A]]
+ // CHECK-DAG: [[C3:%.*]] = index.constant 3
+ // CHECK-DAG: [[V0:%.*]] = index.add %arg0, [[C3]]
// CHECK: return [[V0]]
- return %5 : index
+ return %3 : index
}
// CHECK-LABEL: @sub
More information about the Mlir-commits
mailing list