[Mlir-commits] [mlir] fbe91fe - [mlir][arith] Canonicalize `addi(x, muli(y, -1))` -> `subi(x, y)`
Jakub Kuderski
llvmlistbot at llvm.org
Mon Mar 6 16:29:50 PST 2023
Author: Jakub Kuderski
Date: 2023-03-06T19:28:39-05:00
New Revision: fbe91fe2cc3bd2c907e63f30db719204aaaf3973
URL: https://github.com/llvm/llvm-project/commit/fbe91fe2cc3bd2c907e63f30db719204aaaf3973
DIFF: https://github.com/llvm/llvm-project/commit/fbe91fe2cc3bd2c907e63f30db719204aaaf3973.diff
LOG: [mlir][arith] Canonicalize `addi(x, muli(y, -1))` -> `subi(x, y)`
Without this canonicalization, these multiply-by-(-1) negations propagate all
the way down to SPIR-V and result in suspicious-looking code with large constants.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D145423
Added:
Modified:
mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Dialect/Arith/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index abf3db1728dcf..7c687142247a6 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -49,6 +49,27 @@ def AddISubConstantLHS :
(ConstantLikeMatcher APIntAttr:$c1)),
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;
+def IsScalarOrSplatNegativeOne : // Holds for an integer (or integer-splat) attribute with all bits set, i.e. -1.
+ Constraint<And<[
+ CPred<"succeeded(getIntOrSplatIntValue($0))">,
+ CPred<"getIntOrSplatIntValue($0)->isAllOnes()">]>>; // Dereference relies on And<> emitting a short-circuiting && after the check above.
+
+// addi(x, muli(y, -1)) -> subi(x, y), where -1 is a scalar or splat constant.
+def AddIMulNegativeOneRhs : // Sound for any bit width: x + y * -1 == x - y modulo 2^n.
+ Pat<(Arith_AddIOp
+ $x,
+ (Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0))),
+ (Arith_SubIOp $x, $y),
+ [(IsScalarOrSplatNegativeOne $c0)]>;
+
+// addi(muli(x, -1), y) -> subi(y, x), where -1 is a scalar or splat constant.
+def AddIMulNegativeOneLhs :
+ Pat<(Arith_AddIOp
+ (Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0)),
+ $y),
+ (Arith_SubIOp $y, $x),
+ [(IsScalarOrSplatNegativeOne $c0)]>;
+
//===----------------------------------------------------------------------===//
// AddUIExtendedOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index f6308a6b000b0..e56f4526291aa 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -258,8 +258,8 @@ OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
- context);
+ patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
+ AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index eaafa9e93ceaa..396f5ee3dc6ea 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -735,6 +735,72 @@ func.func @doubleAddSub2(%arg0: index, %arg1 : index) -> index {
return %add : index
}
+// CHECK-LABEL: @addiMuliToSubiRhsI32
+// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+// CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : i32
+// CHECK: return %[[SUB]]
+func.func @addiMuliToSubiRhsI32(%arg0: i32, %arg1: i32) -> i32 {
+ %c-1 = arith.constant -1 : i32
+ %neg = arith.muli %arg1, %c-1 : i32 // -%arg1
+ %add = arith.addi %arg0, %neg : i32
+ return %add : i32
+}
+
+// CHECK-LABEL: @addiMuliToSubiRhsIndex
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
+// CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : index
+// CHECK: return %[[SUB]]
+func.func @addiMuliToSubiRhsIndex(%arg0: index, %arg1: index) -> index {
+ %c-1 = arith.constant -1 : index
+ %neg = arith.muli %arg1, %c-1 : index // -%arg1
+ %add = arith.addi %arg0, %neg : index
+ return %add : index
+}
+
+// CHECK-LABEL: @addiMuliToSubiRhsVector
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi64>, %[[ARG1:.+]]: vector<3xi64>)
+// CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : vector<3xi64>
+// CHECK: return %[[SUB]]
+func.func @addiMuliToSubiRhsVector(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi64> {
+ %c-1 = arith.constant dense<-1> : vector<3xi64>
+ %neg = arith.muli %arg1, %c-1 : vector<3xi64> // -%arg1 (splat -1)
+ %add = arith.addi %arg0, %neg : vector<3xi64>
+ return %add : vector<3xi64>
+}
+
+// CHECK-LABEL: @addiMuliToSubiLhsI32
+// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+// CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : i32
+// CHECK: return %[[SUB]]
+func.func @addiMuliToSubiLhsI32(%arg0: i32, %arg1: i32) -> i32 {
+ %c-1 = arith.constant -1 : i32
+ %neg = arith.muli %arg1, %c-1 : i32 // -%arg1
+ %add = arith.addi %neg, %arg0 : i32
+ return %add : i32
+}
+
+// CHECK-LABEL: @addiMuliToSubiLhsIndex
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
+// CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : index
+// CHECK: return %[[SUB]]
+func.func @addiMuliToSubiLhsIndex(%arg0: index, %arg1: index) -> index {
+ %c-1 = arith.constant -1 : index
+ %neg = arith.muli %arg1, %c-1 : index // -%arg1
+ %add = arith.addi %neg, %arg0 : index
+ return %add : index
+}
+
+// CHECK-LABEL: @addiMuliToSubiLhsVector
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi64>, %[[ARG1:.+]]: vector<3xi64>)
+// CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : vector<3xi64>
+// CHECK: return %[[SUB]]
+func.func @addiMuliToSubiLhsVector(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi64> {
+ %c-1 = arith.constant dense<-1> : vector<3xi64>
+ %neg = arith.muli %arg1, %c-1 : vector<3xi64> // -%arg1 (splat -1)
+ %add = arith.addi %neg, %arg0 : vector<3xi64>
+ return %add : vector<3xi64>
+}
+
// CHECK-LABEL: @adduiExtendedZeroRhs
// CHECK-NEXT: %[[false:.+]] = arith.constant false
// CHECK-NEXT: return %arg0, %[[false]]
More information about the Mlir-commits mailing list