[Mlir-commits] [mlir] [MLIR] Add folding constants canonicalization for mlir::index::AddOp. (PR #111084)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 3 19:48:18 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-index
Author: weiwei chen (weiweichen)
<details>
<summary>Changes</summary>
- [x] Add a simple canonicalization for `mlir::index::AddOp`.
---
Full diff: https://github.com/llvm/llvm-project/pull/111084.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Index/IR/IndexOps.td (+2)
- (modified) mlir/lib/Dialect/Index/IR/IndexOps.cpp (+39)
- (modified) mlir/test/Dialect/Index/index-canonicalize.mlir (+15)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index a30ae9f739cbc6..ce1355316b09b8 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -56,6 +56,8 @@ def Index_AddOp : IndexBinaryOp<"add", [Commutative, Pure]> {
%c = index.add %a, %b
```
}];
+
+ let hasCanonicalizeMethod = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index 42401dae217ce1..ace9b43014a665 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -136,6 +136,45 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
return {};
}
+/// Canonicalize a chain of two `index.add` ops that each carry a constant
+/// operand (on either side) into a single add of the summed constant:
+///   `x = v + c1; y = x + c2`  to  `y = v + (c1 + c2)`
+///   (likewise for `x = c1 + v` and/or `y = c2 + x`)
+/// Returns failure, leaving the IR untouched, when the pattern does not match.
+LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
+  // Splits `op` into its constant operand `c` and the other operand `v`,
+  // trying the RHS first; returns false if neither operand is a constant.
+  auto matchConstant = [](mlir::index::AddOp op, Value &v, IntegerAttr &c) {
+    if (mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c))) {
+      v = op.getLhs();
+      return true;
+    }
+    v = op.getRhs();
+    return mlir::matchPattern(op.getLhs(), mlir::m_Constant(&c));
+  };
+
+  IntegerAttr c1, c2;
+  Value v1, v2;
+  if (!matchConstant(op, v1, c1))
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "neither LHS nor RHS is constant");
+
+  // `v1` is whichever operand of `op` is not the matched constant.
+  auto add = v1.getDefiningOp<mlir::index::AddOp>();
+  if (!add)
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "non-constant operand is not an add");
+  if (!matchConstant(add, v2, c2))
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "inner add has no constant operand");
+
+  // Sum as APInt: wraps on overflow instead of hitting signed-overflow UB.
+  const int64_t sum = (c1.getValue() + c2.getValue()).getSExtValue();
+  auto c = rewriter.create<mlir::index::ConstantOp>(op->getLoc(), sum);
+  auto newAdd = rewriter.create<mlir::index::AddOp>(op->getLoc(), v2, c);
+  rewriter.replaceOp(op, newAdd);
+  return success();
+}
//===----------------------------------------------------------------------===//
// SubOp
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index 37aa33bfde952e..256e327e83ea9c 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -32,6 +32,21 @@ func.func @add_overflow() -> (index, index) {
return %2, %3 : index, index
}
+// CHECK-LABEL: @add_fold_constants
+func.func @add_fold_constants(%arg: index) -> (index) {
+  %0 = index.constant 1
+  %1 = index.constant 2
+  %2 = index.add %arg, %0
+  %3 = index.add %1, %2
+  %4 = index.add %3, %1
+  %5 = index.add %4, %0
+  // The chained constants 1 + 2 + 2 + 1 must fold into one `index.constant 6`.
+  // CHECK-DAG: [[A:%.*]] = index.constant 6
+  // CHECK-DAG: [[V0:%.*]] = index.add %arg0, [[A]]
+  // CHECK: return [[V0]]
+  return %5 : index
+}
+
// CHECK-LABEL: @sub
func.func @sub() -> index {
%0 = index.constant -2000000000
``````````
</details>
https://github.com/llvm/llvm-project/pull/111084
More information about the Mlir-commits
mailing list