[Mlir-commits] [mlir] d8a57c7 - [mlir][arith] Canonicalization patterns for subi.
Slava Zakharin
llvmlistbot at llvm.org
Tue Sep 13 09:55:05 PDT 2022
Author: Slava Zakharin
Date: 2022-09-13T09:42:29-07:00
New Revision: d8a57c778875eb0a2d093b6f1991e77ff9885a85
URL: https://github.com/llvm/llvm-project/commit/d8a57c778875eb0a2d093b6f1991e77ff9885a85
DIFF: https://github.com/llvm/llvm-project/commit/d8a57c778875eb0a2d093b6f1991e77ff9885a85.diff
LOG: [mlir][arith] Canonicalization patterns for subi.
Add the following simplifications for arith.subi:
subi(addi(a, b), b) -> a            (folder)
subi(addi(a, b), a) -> b            (folder)
subi(subi(a, b), a) -> subi(0, b)   (canonicalization pattern)
Differential Revision: https://reviews.llvm.org/D133615
Added:
Modified:
mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
mlir/test/Dialect/Arithmetic/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
index 647ddfee74d0f..0dd8039ec0b4e 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
@@ -12,6 +12,9 @@
include "mlir/IR/PatternBase.td"
include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td"
+// Create zero attribute of type matching the argument's type.
+def GetZeroAttr : NativeCodeCall<"$_builder.getZeroAttr($0.getType())">;
+
// Add two integer attributes and create a new one with the result.
def AddIntAttrs : NativeCodeCall<"addIntegerAttrs($_builder, $0, $1, $2)">;
@@ -92,6 +95,11 @@ def SubILHSSubConstantLHS :
(Arith_SubIOp (Arith_ConstantOp APIntAttr:$c0), $x)),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)))>;
+// subi(subi(a, b), a) -> subi(0, b)
+def SubISubILHSRHSLHS :
+ Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y), $x),
+ (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y)>;
+
//===----------------------------------------------------------------------===//
// XOrIOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index e57422b9d341e..bd693478a3533 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -305,16 +305,24 @@ OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
if (matchPattern(getRhs(), m_Zero()))
return getLhs();
+ if (auto add = getLhs().getDefiningOp<AddIOp>()) {
+ // subi(addi(a, b), b) -> a
+ if (getRhs() == add.getRhs())
+ return add.getLhs();
+ // subi(addi(a, b), a) -> b
+ if (getRhs() == add.getLhs())
+ return add.getRhs();
+ }
+
return constFoldBinaryOp<IntegerAttr>(
operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
}
void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns
- .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
- SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>(
- context);
+ patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
+ SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
+ SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index 7076ab81e6233..649da010cd359 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -480,6 +480,16 @@ func.func @tripleSubAdd1(%arg0: index) -> index {
return %add2 : index
}
+// CHECK-LABEL: @subSub0
+// CHECK: %[[c0:.+]] = arith.constant 0 : index
+// CHECK: %[[add:.+]] = arith.subi %[[c0]], %arg1 : index
+// CHECK: return %[[add]]
+func.func @subSub0(%arg0: index, %arg1: index) -> index {
+ %sub1 = arith.subi %arg0, %arg1 : index
+ %sub2 = arith.subi %sub1, %arg0 : index
+ return %sub2 : index
+}
+
// CHECK-LABEL: @tripleSubSub0
// CHECK: %[[cres:.+]] = arith.constant 25 : index
// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
@@ -528,6 +538,22 @@ func.func @tripleSubSub3(%arg0: index) -> index {
return %add2 : index
}
+// CHECK-LABEL: @subAdd1
+// CHECK-NEXT: return %arg0
+func.func @subAdd1(%arg0: index, %arg1 : index) -> index {
+ %add = arith.addi %arg0, %arg1 : index
+ %sub = arith.subi %add, %arg1 : index
+ return %sub : index
+}
+
+// CHECK-LABEL: @subAdd2
+// CHECK-NEXT: return %arg1
+func.func @subAdd2(%arg0: index, %arg1 : index) -> index {
+ %add = arith.addi %arg0, %arg1 : index
+ %sub = arith.subi %add, %arg0 : index
+ return %sub : index
+}
+
// CHECK-LABEL: @doubleAddSub1
// CHECK-NEXT: return %arg0
func.func @doubleAddSub1(%arg0: index, %arg1 : index) -> index {
More information about the Mlir-commits mailing list