[Mlir-commits] [mlir] 795b4ef - [MLIR] Add canonicalizations to all eligible `index` binary ops (#114000)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 4 13:24:26 PST 2024
Author: Nachi G
Date: 2024-11-04T13:24:22-08:00
New Revision: 795b4efad0259cbf03fc98e3045621916328ce57
URL: https://github.com/llvm/llvm-project/commit/795b4efad0259cbf03fc98e3045621916328ce57
DIFF: https://github.com/llvm/llvm-project/commit/795b4efad0259cbf03fc98e3045621916328ce57.diff
LOG: [MLIR] Add canonicalizations to all eligible `index` binary ops (#114000)
Generalizes the following canonicalization pattern, where `c1` and `c2` are
constants, to all associative and commutative binary ops in the `index`
dialect.
```
x = v + c1
y = x + c2
-->
y = v + (c1 + c2)
```
This includes:
- `AddOp`
- `MulOp`
- `MaxSOp`
- `MaxUOp`
- `MinSOp`
- `MinUOp`
- `AndOp`
- `OrOp`
- `XOrOp`
The operation folding is implemented using the existing folders since
`createOrFold` is used in the canonicalization; if the two constants cannot be
folded, the pattern does not apply.
Added:
Modified:
mlir/include/mlir/Dialect/Index/IR/IndexOps.td
mlir/lib/Dialect/Index/IR/IndexOps.cpp
mlir/test/Dialect/Index/index-canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index ce1355316b09b8..230a3815bdd81e 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -95,6 +95,8 @@ def Index_MulOp : IndexBinaryOp<"mul", [Commutative, Pure]> {
%c = index.mul %a, %b
```
}];
+
+ let hasCanonicalizeMethod = 1;
}
//===----------------------------------------------------------------------===//
@@ -263,6 +265,8 @@ def Index_MaxSOp : IndexBinaryOp<"maxs", [Commutative, Pure]> {
%c = index.maxs %a, %b
```
}];
+
+ let hasCanonicalizeMethod = 1;
}
//===----------------------------------------------------------------------===//
@@ -283,6 +287,8 @@ def Index_MaxUOp : IndexBinaryOp<"maxu", [Commutative, Pure]> {
%c = index.maxu %a, %b
```
}];
+
+ let hasCanonicalizeMethod = 1;
}
//===----------------------------------------------------------------------===//
@@ -302,6 +308,8 @@ def Index_MinSOp : IndexBinaryOp<"mins", [Commutative, Pure]> {
%c = index.mins %a, %b
```
}];
+
+ let hasCanonicalizeMethod = 1;
}
//===----------------------------------------------------------------------===//
@@ -322,6 +330,8 @@ def Index_MinUOp : IndexBinaryOp<"minu", [Commutative, Pure]> {
%c = index.minu %a, %b
```
}];
+
+ let hasCanonicalizeMethod = 1;
}
//===----------------------------------------------------------------------===//
@@ -404,6 +414,8 @@ def Index_AndOp : IndexBinaryOp<"and", [Commutative, Pure]> {
%c = index.and %a, %b
```
}];
+
+ let hasCanonicalizeMethod = 1;
}
//===----------------------------------------------------------------------===//
@@ -423,6 +435,8 @@ def Index_OrOp : IndexBinaryOp<"or", [Commutative, Pure]> {
%c = index.or %a, %b
```
}];
+
+ let hasCanonicalizeMethod = 1;
}
//===----------------------------------------------------------------------===//
@@ -442,6 +456,8 @@ def Index_XOrOp : IndexBinaryOp<"xor", [Commutative, Pure]> {
%c = index.xor %a, %b
```
}];
+
+ let hasCanonicalizeMethod = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index 0b58eb80f93032..5c935c5f4b53e3 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -118,6 +118,32 @@ static OpFoldResult foldBinaryOpChecked(
return IntegerAttr::get(IndexType::get(lhs.getContext()), *result64);
}
+/// Helper for associative and commutative binary ops that can be transformed:
+/// `x = op(v, c1); y = op(x, c2)` -> `tmp = op(c1, c2); y = op(v, tmp)`
+/// where c1 and c2 are constants. `tmp` is built with `createOrFold`; when the
+/// op's folder cannot reduce it to a constant, the pattern reports failure
+/// and first erases the unfolded `tmp` so that the IR is left unchanged.
+template <typename BinaryOp>
+static LogicalResult
+canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op,
+                                           PatternRewriter &rewriter) {
+  if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant()))
+    return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
+
+  auto lhsOp = op.getLhs().template getDefiningOp<BinaryOp>();
+  if (!lhsOp)
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "LHS is not the same BinaryOp");
+
+  if (!mlir::matchPattern(lhsOp.getRhs(), mlir::m_Constant()))
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "RHS of LHS op is not a constant");
+
+  // On success createOrFold materializes the folded constant; otherwise it
+  // leaves a newly inserted, unfolded BinaryOp in the IR.
+  Value c = rewriter.createOrFold<BinaryOp>(op->getLoc(), op.getRhs(),
+                                            lhsOp.getRhs());
+  if (auto unfolded = c.getDefiningOp<BinaryOp>()) {
+    // A pattern that returns failure must not have modified the IR, so drop
+    // the op createOrFold just created before bailing out.
+    rewriter.eraseOp(unfolded);
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "new BinaryOp was not folded");
+  }
+
+  rewriter.replaceOpWithNewOp<BinaryOp>(op, lhsOp.getLhs(), c);
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
@@ -136,27 +162,9 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
return {};
}
-/// Canonicalize
-/// ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)`
-LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
- IntegerAttr c1, c2;
- if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c1)))
- return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
-
- auto add = op.getLhs().getDefiningOp<mlir::index::AddOp>();
- if (!add)
- return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not a add");
-
- if (!mlir::matchPattern(add.getRhs(), mlir::m_Constant(&c2)))
- return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
-
- auto c = rewriter.create<mlir::index::ConstantOp>(op->getLoc(),
- c1.getInt() + c2.getInt());
- auto newAdd =
- rewriter.create<mlir::index::AddOp>(op->getLoc(), add.getLhs(), c);
- rewriter.replaceOp(op, newAdd);
- return success();
+/// Canonicalize `x = v + c1; y = x + c2` to `y = v + (c1 + c2)`, where the
+/// constant sum is folded. Shares its implementation with the other
+/// associative and commutative ops.
+LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
+  return canonicalizeAssociativeCommutativeBinaryOp<AddOp>(op, rewriter);
}
//===----------------------------------------------------------------------===//
@@ -200,6 +208,10 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
return {};
}
+/// Canonicalize `mul(mul(v, c1), c2)` to `mul(v, c)`, where `c` is the
+/// constant obtained by folding `c1 * c2`.
+LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
+  return canonicalizeAssociativeCommutativeBinaryOp<MulOp>(op, rewriter);
+}
+
//===----------------------------------------------------------------------===//
// DivSOp
//===----------------------------------------------------------------------===//
@@ -352,6 +364,10 @@ OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) {
});
}
+/// Canonicalize `maxs(maxs(v, c1), c2)` to `maxs(v, c)`, where `c` is the
+/// constant obtained by folding `maxs(c1, c2)`.
+LogicalResult MaxSOp::canonicalize(MaxSOp op, PatternRewriter &rewriter) {
+  return canonicalizeAssociativeCommutativeBinaryOp<MaxSOp>(op, rewriter);
+}
+
//===----------------------------------------------------------------------===//
// MaxUOp
//===----------------------------------------------------------------------===//
@@ -363,6 +379,10 @@ OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
});
}
+/// Canonicalize `maxu(maxu(v, c1), c2)` to `maxu(v, c)`, where `c` is the
+/// constant obtained by folding `maxu(c1, c2)`.
+LogicalResult MaxUOp::canonicalize(MaxUOp op, PatternRewriter &rewriter) {
+  return canonicalizeAssociativeCommutativeBinaryOp<MaxUOp>(op, rewriter);
+}
+
//===----------------------------------------------------------------------===//
// MinSOp
//===----------------------------------------------------------------------===//
@@ -374,6 +394,10 @@ OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
});
}
+/// Canonicalize `mins(mins(v, c1), c2)` to `mins(v, c)`, where `c` is the
+/// constant obtained by folding `mins(c1, c2)`.
+LogicalResult MinSOp::canonicalize(MinSOp op, PatternRewriter &rewriter) {
+  return canonicalizeAssociativeCommutativeBinaryOp<MinSOp>(op, rewriter);
+}
+
//===----------------------------------------------------------------------===//
// MinUOp
//===----------------------------------------------------------------------===//
@@ -385,6 +409,10 @@ OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
});
}
+/// Canonicalize `minu(minu(v, c1), c2)` to `minu(v, c)`, where `c` is the
+/// constant obtained by folding `minu(c1, c2)`.
+LogicalResult MinUOp::canonicalize(MinUOp op, PatternRewriter &rewriter) {
+  return canonicalizeAssociativeCommutativeBinaryOp<MinUOp>(op, rewriter);
+}
+
//===----------------------------------------------------------------------===//
// ShlOp
//===----------------------------------------------------------------------===//
@@ -442,6 +470,10 @@ OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
[](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
}
+/// Canonicalize `and(and(v, c1), c2)` to `and(v, c)`, where `c` is the
+/// constant obtained by folding `c1 & c2`.
+LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
+  return canonicalizeAssociativeCommutativeBinaryOp<AndOp>(op, rewriter);
+}
+
//===----------------------------------------------------------------------===//
// OrOp
//===----------------------------------------------------------------------===//
@@ -452,6 +484,10 @@ OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
[](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
}
+/// Canonicalize `or(or(v, c1), c2)` to `or(v, c)`, where `c` is the
+/// constant obtained by folding `c1 | c2`.
+LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
+  return canonicalizeAssociativeCommutativeBinaryOp<OrOp>(op, rewriter);
+}
+
//===----------------------------------------------------------------------===//
// XOrOp
//===----------------------------------------------------------------------===//
@@ -462,6 +498,10 @@ OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
[](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
}
+/// Canonicalize `xor(xor(v, c1), c2)` to `xor(v, c)`, where `c` is the
+/// constant obtained by folding `c1 ^ c2`.
+LogicalResult XOrOp::canonicalize(XOrOp op, PatternRewriter &rewriter) {
+  return canonicalizeAssociativeCommutativeBinaryOp<XOrOp>(op, rewriter);
+}
+
//===----------------------------------------------------------------------===//
// CastSOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index a29b09c11f7f62..45da6ea57d796e 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -32,15 +32,15 @@ func.func @add_overflow() -> (index, index) {
return %2, %3 : index, index
}
-// CHECK-LABEL: @add
+// CHECK-LABEL: @add_fold_constants
func.func @add_fold_constants(%arg: index) -> (index) {
%0 = index.constant 1
%1 = index.constant 2
%2 = index.add %arg, %0
%3 = index.add %2, %1
- // CHECK-DAG: [[C3:%.*]] = index.constant 3
- // CHECK-DAG: [[V0:%.*]] = index.add %arg0, [[C3]]
+ // (v + 1) + 2 is reassociated so that the two constants fold into 3.
+ // CHECK: [[C3:%.*]] = index.constant 3
+ // CHECK: [[V0:%.*]] = index.add %arg0, [[C3]]
// CHECK: return [[V0]]
return %3 : index
}
@@ -65,6 +65,19 @@ func.func @mul() -> index {
return %2 : index
}
+// CHECK-LABEL: @mul_fold_constants
+func.func @mul_fold_constants(%arg: index) -> (index) {
+  // (v * 2) * 3 reassociates to v * 6.
+  %c2 = index.constant 2
+  %c3 = index.constant 3
+  %inner = index.mul %arg, %c2
+  %outer = index.mul %inner, %c3
+  // CHECK: [[PROD:%.*]] = index.constant 6
+  // CHECK: [[RES:%.*]] = index.mul %arg0, [[PROD]]
+  // CHECK: return [[RES]]
+  return %outer : index
+}
+
// CHECK-LABEL: @divs
func.func @divs() -> index {
%0 = index.constant -2
@@ -300,6 +313,19 @@ func.func @maxs_edge() -> index {
return %0 : index
}
+// CHECK-LABEL: @maxs_fold_constants
+func.func @maxs_fold_constants(%arg: index) -> (index) {
+  %0 = index.constant -2
+  %1 = index.constant 3
+  %2 = index.maxs %arg, %0
+  %3 = index.maxs %2, %1
+
+  // CHECK: [[C3:%.*]] = index.constant 3
+  // CHECK: [[V0:%.*]] = index.maxs %arg0, [[C3]]
+  // CHECK: return [[V0]]
+  return %3 : index
+}
+
+// The folder is expected to reject maxs(1, 2^32): 2^32 truncates to 0 for a
+// 32-bit index, so the result depends on the index bitwidth. The
+// reassociation must then bail out and leave both ops untouched.
+// CHECK-LABEL: @maxs_nofold_constants
+func.func @maxs_nofold_constants(%arg: index) -> (index) {
+  %0 = index.constant 4294967296
+  %1 = index.constant 1
+  %2 = index.maxs %arg, %0
+  %3 = index.maxs %2, %1
+
+  // CHECK: [[V0:%.*]] = index.maxs %arg0, %{{.*}}
+  // CHECK: [[V1:%.*]] = index.maxs [[V0]], %{{.*}}
+  // CHECK: return [[V1]]
+  return %3 : index
+}
+
// CHECK-LABEL: @maxu
func.func @maxu() -> index {
%lhs = index.constant -1
@@ -310,6 +336,19 @@ func.func @maxu() -> index {
return %0 : index
}
+// CHECK-LABEL: @maxu_fold_constants
+func.func @maxu_fold_constants(%arg: index) -> (index) {
+  // maxu(maxu(v, 2), 3) reassociates to maxu(v, 3).
+  %c2 = index.constant 2
+  %c3 = index.constant 3
+  %inner = index.maxu %arg, %c2
+  %outer = index.maxu %inner, %c3
+  // CHECK: [[MAXC:%.*]] = index.constant 3
+  // CHECK: [[RES:%.*]] = index.maxu %arg0, [[MAXC]]
+  // CHECK: return [[RES]]
+  return %outer : index
+}
+
// CHECK-LABEL: @mins
func.func @mins() -> index {
%lhs = index.constant -4
@@ -340,6 +379,19 @@ func.func @mins_nofold_2() -> index {
return %0 : index
}
+// CHECK-LABEL: @mins_fold_constants
+func.func @mins_fold_constants(%arg: index) -> (index) {
+  // mins(mins(v, -2), 3) reassociates to mins(v, mins(-2, 3)) = mins(v, -2).
+  %cm2 = index.constant -2
+  %c3 = index.constant 3
+  %inner = index.mins %arg, %cm2
+  %outer = index.mins %inner, %c3
+  // CHECK: [[CM2:%.*]] = index.constant -2
+  // CHECK: [[RES:%.*]] = index.mins %arg0, [[CM2]]
+  // CHECK: return [[RES]]
+  return %outer : index
+}
+
// CHECK-LABEL: @minu
func.func @minu() -> index {
%lhs = index.constant -1
@@ -350,6 +402,19 @@ func.func @minu() -> index {
return %0 : index
}
+// CHECK-LABEL: @minu_fold_constants
+func.func @minu_fold_constants(%arg: index) -> (index) {
+  // minu(minu(v, 2), 3) reassociates to minu(v, 2).
+  %c2 = index.constant 2
+  %c3 = index.constant 3
+  %inner = index.minu %arg, %c2
+  %outer = index.minu %inner, %c3
+  // CHECK: [[MINC:%.*]] = index.constant 2
+  // CHECK: [[RES:%.*]] = index.minu %arg0, [[MINC]]
+  // CHECK: return [[RES]]
+  return %outer : index
+}
+
// CHECK-LABEL: @shl
func.func @shl() -> index {
%lhs = index.constant 128
@@ -465,6 +530,19 @@ func.func @and() -> index {
return %0 : index
}
+// CHECK-LABEL: @and_fold_constants
+func.func @and_fold_constants(%arg: index) -> (index) {
+  // and(and(v, 5), 1) reassociates to and(v, 5 & 1) = and(v, 1).
+  %c5 = index.constant 5
+  %c1 = index.constant 1
+  %inner = index.and %arg, %c5
+  %outer = index.and %inner, %c1
+  // CHECK: [[MASK:%.*]] = index.constant 1
+  // CHECK: [[RES:%.*]] = index.and %arg0, [[MASK]]
+  // CHECK: return [[RES]]
+  return %outer : index
+}
+
// CHECK-LABEL: @or
func.func @or() -> index {
%lhs = index.constant 5
@@ -475,6 +553,19 @@ func.func @or() -> index {
return %0 : index
}
+// CHECK-LABEL: @or_fold_constants
+func.func @or_fold_constants(%arg: index) -> (index) {
+  // or(or(v, 5), 1) reassociates to or(v, 5 | 1) = or(v, 5).
+  %c5 = index.constant 5
+  %c1 = index.constant 1
+  %inner = index.or %arg, %c5
+  %outer = index.or %inner, %c1
+  // CHECK: [[BITS:%.*]] = index.constant 5
+  // CHECK: [[RES:%.*]] = index.or %arg0, [[BITS]]
+  // CHECK: return [[RES]]
+  return %outer : index
+}
+
// CHECK-LABEL: @xor
func.func @xor() -> index {
%lhs = index.constant 5
@@ -485,6 +576,19 @@ func.func @xor() -> index {
return %0 : index
}
+// CHECK-LABEL: @xor_fold_constants
+func.func @xor_fold_constants(%arg: index) -> (index) {
+  // xor(xor(v, 5), 1) reassociates to xor(v, 5 ^ 1) = xor(v, 4).
+  %c5 = index.constant 5
+  %c1 = index.constant 1
+  %inner = index.xor %arg, %c5
+  %outer = index.xor %inner, %c1
+  // CHECK: [[FLIP:%.*]] = index.constant 4
+  // CHECK: [[RES:%.*]] = index.xor %arg0, [[FLIP]]
+  // CHECK: return [[RES]]
+  return %outer : index
+}
+
// CHECK-LABEL: @cmp
func.func @cmp(%arg0: index) -> (i1, i1, i1, i1, i1, i1) {
%a = index.constant 0
More information about the Mlir-commits
mailing list