[Mlir-commits] [mlir] d8a57c7 - [mlir][arith] Canonicalization patterns for subi.

Slava Zakharin llvmlistbot at llvm.org
Tue Sep 13 09:55:05 PDT 2022


Author: Slava Zakharin
Date: 2022-09-13T09:42:29-07:00
New Revision: d8a57c778875eb0a2d093b6f1991e77ff9885a85

URL: https://github.com/llvm/llvm-project/commit/d8a57c778875eb0a2d093b6f1991e77ff9885a85
DIFF: https://github.com/llvm/llvm-project/commit/d8a57c778875eb0a2d093b6f1991e77ff9885a85.diff

LOG: [mlir][arith] Canonicalization patterns for subi.

subi(addi(a, b), b) -> a
subi(addi(a, b), a) -> b
subi(subi(a, b), a) -> subi(0, b)

Differential Revision: https://reviews.llvm.org/D133615

Added: 
    

Modified: 
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
    mlir/test/Dialect/Arithmetic/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
index 647ddfee74d0f..0dd8039ec0b4e 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
@@ -12,6 +12,9 @@
 include "mlir/IR/PatternBase.td"
 include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td"
 
+// Create a zero attribute whose type is the type of the given value.
+def GetZeroAttr : NativeCodeCall<"$_builder.getZeroAttr($0.getType())">;
+
 // Add two integer attributes and create a new one with the result.
 def AddIntAttrs : NativeCodeCall<"addIntegerAttrs($_builder, $0, $1, $2)">;
 
@@ -92,6 +95,11 @@ def SubILHSSubConstantLHS :
          (Arith_SubIOp (Arith_ConstantOp APIntAttr:$c0), $x)),
        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)))>;
 
+// subi(subi(a, b), a) -> subi(0, b), because (a - b) - a == 0 - b.
+def SubISubILHSRHSLHS :
+    Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y), $x),
+        (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y)>;
+
 //===----------------------------------------------------------------------===//
 // XOrIOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index e57422b9d341e..bd693478a3533 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -305,16 +305,24 @@ OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
   if (matchPattern(getRhs(), m_Zero()))
     return getLhs();
 
+  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
+    // subi(addi(a, b), b) -> a; (a + b) - b == a even if the add wraps.
+    if (getRhs() == add.getRhs())
+      return add.getLhs();
+    // subi(addi(a, b), a) -> b; either addi operand may cancel.
+    if (getRhs() == add.getLhs())
+      return add.getRhs();
+  }
+
   return constFoldBinaryOp<IntegerAttr>(
       operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
 }
 
 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
-  patterns
-      .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
-           SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>(
-          context);
+  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
+               SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
+               SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index 7076ab81e6233..649da010cd359 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -480,6 +480,16 @@ func.func @tripleSubAdd1(%arg0: index) -> index {
   return %add2 : index
 }
 
+// CHECK-LABEL: @subSub0
+//       CHECK:   %[[c0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[sub:.+]] = arith.subi %[[c0]], %arg1 : index
+//       CHECK:   return %[[sub]]
+func.func @subSub0(%arg0: index, %arg1: index) -> index {
+  %sub1 = arith.subi %arg0, %arg1 : index
+  %sub2 = arith.subi %sub1, %arg0 : index
+  return %sub2 : index
+}
+
 // CHECK-LABEL: @tripleSubSub0
 //       CHECK:   %[[cres:.+]] = arith.constant 25 : index
 //       CHECK:   %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
@@ -528,6 +538,22 @@ func.func @tripleSubSub3(%arg0: index) -> index {
   return %add2 : index
 }
 
+// CHECK-LABEL: @subAdd1
+//  CHECK-NEXT:   return %arg0
+func.func @subAdd1(%arg0: index, %arg1 : index) -> index {
+  %add = arith.addi %arg0, %arg1 : index
+  %sub = arith.subi %add, %arg1 : index
+  return %sub : index
+}
+
+// CHECK-LABEL: @subAdd2
+//  CHECK-NEXT:   return %arg1
+func.func @subAdd2(%arg0: index, %arg1 : index) -> index {
+  %add = arith.addi %arg0, %arg1 : index
+  %sub = arith.subi %add, %arg0 : index
+  return %sub : index
+}
+
 // CHECK-LABEL: @doubleAddSub1
 //  CHECK-NEXT:   return %arg0
 func.func @doubleAddSub1(%arg0: index, %arg1 : index) -> index {


        


More information about the Mlir-commits mailing list