[Mlir-commits] [mlir] fbe91fe - [mlir][arith] Canonicalize `addi(x, muli(y, -1))` -> `subi(x, y)`

Jakub Kuderski llvmlistbot at llvm.org
Mon Mar 6 16:29:50 PST 2023


Author: Jakub Kuderski
Date: 2023-03-06T19:28:39-05:00
New Revision: fbe91fe2cc3bd2c907e63f30db719204aaaf3973

URL: https://github.com/llvm/llvm-project/commit/fbe91fe2cc3bd2c907e63f30db719204aaaf3973
DIFF: https://github.com/llvm/llvm-project/commit/fbe91fe2cc3bd2c907e63f30db719204aaaf3973.diff

LOG: [mlir][arith] Canonicalize `addi(x, muli(y, -1))` -> `subi(x, y)`

These `muli(y, -1)` expressions propagate all the way down to SPIR-V and
result in suspicious-looking code with large constants.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D145423

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/test/Dialect/Arith/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index abf3db1728dcf..7c687142247a6 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -49,6 +49,27 @@ def AddISubConstantLHS :
           (ConstantLikeMatcher APIntAttr:$c1)),
         (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;
 
+def IsScalarOrSplatNegativeOne : // Holds when $0 is an integer (scalar or splat) equal to -1.
+    Constraint<And<[
+      CPred<"succeeded(getIntOrSplatIntValue($0))">, // $0 has a scalar or splat integer value.
+      CPred<"getIntOrSplatIntValue($0)->isAllOnes()">]>>; // All bits set, i.e. -1 at any width.
+
+// addi(x, muli(y, -1)) -> subi(x, y); -1 (scalar or splat) must be the muli RHS.
+def AddIMulNegativeOneRhs :
+    Pat<(Arith_AddIOp
+           $x,
+           (Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0))),
+        (Arith_SubIOp $x, $y),
+        [(IsScalarOrSplatNegativeOne $c0)]>;
+
+// addi(muli(x, -1), y) -> subi(y, x); -1 (scalar or splat) must be the muli RHS.
+def AddIMulNegativeOneLhs :
+    Pat<(Arith_AddIOp
+           (Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0)),
+           $y),
+        (Arith_SubIOp $y, $x),
+        [(IsScalarOrSplatNegativeOne $c0)]>;
+
 //===----------------------------------------------------------------------===//
 // AddUIExtendedOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index f6308a6b000b0..e56f4526291aa 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -258,8 +258,8 @@ OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
 
 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
-  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
-      context);
+  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
+               AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context); // x + y*(-1) -> x - y
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index eaafa9e93ceaa..396f5ee3dc6ea 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -735,6 +735,72 @@ func.func @doubleAddSub2(%arg0: index, %arg1 : index) -> index {
   return %add : index
 }
 
+// CHECK-LABEL: @addiMuliToSubiRhsI32
+//  CHECK-SAME:   (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+//       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : i32
+//       CHECK:   return %[[SUB]]
+func.func @addiMuliToSubiRhsI32(%arg0: i32, %arg1: i32) -> i32 {
+  %c-1 = arith.constant -1 : i32
+  %neg = arith.muli %arg1, %c-1 : i32
+  %add = arith.addi %arg0, %neg : i32  // arg0 + arg1 * -1, negated operand on the addi RHS.
+  return %add : i32
+}
+
+// CHECK-LABEL: @addiMuliToSubiRhsIndex
+//  CHECK-SAME:   (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
+//       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : index
+//       CHECK:   return %[[SUB]]
+func.func @addiMuliToSubiRhsIndex(%arg0: index, %arg1: index) -> index {
+  %c-1 = arith.constant -1 : index
+  %neg = arith.muli %arg1, %c-1 : index
+  %add = arith.addi %arg0, %neg : index  // Same rewrite as the i32 case, on index.
+  return %add : index
+}
+
+// CHECK-LABEL: @addiMuliToSubiRhsVector
+//  CHECK-SAME:   (%[[ARG0:.+]]: vector<3xi64>, %[[ARG1:.+]]: vector<3xi64>)
+//       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : vector<3xi64>
+//       CHECK:   return %[[SUB]]
+func.func @addiMuliToSubiRhsVector(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi64> {
+  %c-1 = arith.constant dense<-1> : vector<3xi64>
+  %neg = arith.muli %arg1, %c-1 : vector<3xi64>
+  %add = arith.addi %arg0, %neg : vector<3xi64>  // Splat -1 must also trigger the rewrite.
+  return %add : vector<3xi64>
+}
+
+// CHECK-LABEL: @addiMuliToSubiLhsI32
+//  CHECK-SAME:   (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+//       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : i32
+//       CHECK:   return %[[SUB]]
+func.func @addiMuliToSubiLhsI32(%arg0: i32, %arg1: i32) -> i32 {
+  %c-1 = arith.constant -1 : i32
+  %neg = arith.muli %arg1, %c-1 : i32
+  %add = arith.addi %neg, %arg0 : i32  // arg1 * -1 + arg0, negated operand on the addi LHS.
+  return %add : i32
+}
+
+// CHECK-LABEL: @addiMuliToSubiLhsIndex
+//  CHECK-SAME:   (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
+//       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : index
+//       CHECK:   return %[[SUB]]
+func.func @addiMuliToSubiLhsIndex(%arg0: index, %arg1: index) -> index {
+  %c-1 = arith.constant -1 : index
+  %neg = arith.muli %arg1, %c-1 : index
+  %add = arith.addi %neg, %arg0 : index  // Same LHS rewrite as the i32 case, on index.
+  return %add : index
+}
+
+// CHECK-LABEL: @addiMuliToSubiLhsVector
+//  CHECK-SAME:   (%[[ARG0:.+]]: vector<3xi64>, %[[ARG1:.+]]: vector<3xi64>)
+//       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : vector<3xi64>
+//       CHECK:   return %[[SUB]]
+func.func @addiMuliToSubiLhsVector(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi64> {
+  %c-1 = arith.constant dense<-1> : vector<3xi64>
+  %neg = arith.muli %arg1, %c-1 : vector<3xi64>
+  %add = arith.addi %neg, %arg0 : vector<3xi64>  // Splat -1 with the negated operand on the LHS.
+  return %add : vector<3xi64>
+}
+
 // CHECK-LABEL: @adduiExtendedZeroRhs
 //  CHECK-NEXT:   %[[false:.+]] = arith.constant false
 //  CHECK-NEXT:   return %arg0, %[[false]]


        


More information about the Mlir-commits mailing list