[Mlir-commits] [mlir] 7191ced - [MLIR] Add folding constants canonicalization for mlir::index::AddOp. (#111084)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 22 12:04:29 PDT 2024
Author: weiwei chen
Date: 2024-10-22T12:04:26-07:00
New Revision: 7191ced3b69e6f4f0e67056be416e399d0a8d7ca
URL: https://github.com/llvm/llvm-project/commit/7191ced3b69e6f4f0e67056be416e399d0a8d7ca
DIFF: https://github.com/llvm/llvm-project/commit/7191ced3b69e6f4f0e67056be416e399d0a8d7ca.diff
LOG: [MLIR] Add folding constants canonicalization for mlir::index::AddOp. (#111084)
- [x] Add a simple canonicalization for `mlir::index::AddOp`.
Added:
Modified:
mlir/include/mlir/Dialect/Index/IR/IndexOps.td
mlir/lib/Dialect/Index/IR/IndexOps.cpp
mlir/test/Dialect/Index/index-canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index a30ae9f739cbc6..ce1355316b09b8 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -56,6 +56,8 @@ def Index_AddOp : IndexBinaryOp<"add", [Commutative, Pure]> {
%c = index.add %a, %b
```
}];
+
+ let hasCanonicalizeMethod = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index 5ad989b7da126e..0b58eb80f93032 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -136,6 +136,28 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
return {};
}
+/// Canonicalize `x = v + c2; y = x + c1` to `y = v + (c1 + c2)`, i.e. merge
+/// the constant RHS operands of two chained index.add ops into one constant.
+LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
+  IntegerAttr c1, c2;
+  if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c1)))
+    return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
+
+  auto add = op.getLhs().getDefiningOp<mlir::index::AddOp>();
+  if (!add)
+    return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not an add");
+
+  if (!mlir::matchPattern(add.getRhs(), mlir::m_Constant(&c2)))
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "RHS of LHS add is not a constant");
+
+  // Sum as APInt: wraps on overflow instead of signed int64_t overflow (UB).
+  const llvm::APInt sum = c1.getValue() + c2.getValue();
+  auto c = rewriter.create<mlir::index::ConstantOp>(op->getLoc(),
+                                                    sum.getSExtValue());
+  rewriter.replaceOpWithNewOp<mlir::index::AddOp>(op, add.getLhs(), c);
+  return success();
+}
//===----------------------------------------------------------------------===//
// SubOp
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index 37aa33bfde952e..a29b09c11f7f62 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -32,6 +32,19 @@ func.func @add_overflow() -> (index, index) {
return %2, %3 : index, index
}
+// CHECK-LABEL: @add_fold_constants
+// CHECK-SAME: (%[[ARG:.*]]: index)
+func.func @add_fold_constants(%arg: index) -> (index) {
+  %0 = index.constant 1
+  %1 = index.constant 2
+  %2 = index.add %arg, %0
+  %3 = index.add %2, %1
+  // CHECK-DAG: [[C3:%.*]] = index.constant 3
+  // CHECK-DAG: [[V0:%.*]] = index.add %[[ARG]], [[C3]]
+  // CHECK: return [[V0]]
+  return %3 : index
+}
+
// CHECK-LABEL: @sub
func.func @sub() -> index {
%0 = index.constant -2000000000
More information about the Mlir-commits
mailing list