[Mlir-commits] [mlir] abc17a6 - [mlir][Arithmetic] Use matchPattern to simplify code.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 22 01:43:02 PDT 2022
Author: jacquesguan
Date: 2022-04-22T08:42:51Z
New Revision: abc17a67519747be36f1fd03e227c5103da4c677
URL: https://github.com/llvm/llvm-project/commit/abc17a67519747be36f1fd03e227c5103da4c677
DIFF: https://github.com/llvm/llvm-project/commit/abc17a67519747be36f1fd03e227c5103da4c677.diff
LOG: [mlir][Arithmetic] Use matchPattern to simplify code.
This patch replaces the hand-written division-by-one checks with matchPattern(getRhs(), m_One()) and moves them before the constant folder call, so the folder is not invoked redundantly when the divisor is one.
Differential Revision: https://reviews.llvm.org/D124235
Added:
Modified:
mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 1fa4b1b8032a2..5a104400c48a8 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -262,6 +262,10 @@ OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
+ // divui (x, 1) -> x.
+ if (matchPattern(getRhs(), m_One()))
+ return getLhs();
+
// Don't fold if it would require a division by zero.
bool div0 = false;
auto result =
@@ -273,15 +277,6 @@ OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
return a.udiv(b);
});
- // Fold out division by one. Assumes all tensors of all ones are splats.
- if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
- if (rhs.getValue() == 1)
- return getLhs();
- } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
- if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
- return getLhs();
- }
-
return div0 ? Attribute() : result;
}
@@ -290,6 +285,10 @@ OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
+ // divsi (x, 1) -> x.
+ if (matchPattern(getRhs(), m_One()))
+ return getLhs();
+
// Don't fold if it would overflow or if it requires a division by zero.
bool overflowOrDiv0 = false;
auto result =
@@ -301,15 +300,6 @@ OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
return a.sdiv_ov(b, overflowOrDiv0);
});
- // Fold out division by one. Assumes all tensors of all ones are splats.
- if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
- if (rhs.getValue() == 1)
- return getLhs();
- } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
- if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
- return getLhs();
- }
-
return overflowOrDiv0 ? Attribute() : result;
}
@@ -330,6 +320,10 @@ static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
//===----------------------------------------------------------------------===//
OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
+ // ceildivui (x, 1) -> x.
+ if (matchPattern(getRhs(), m_One()))
+ return getLhs();
+
bool overflowOrDiv0 = false;
auto result =
constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
@@ -343,15 +337,6 @@ OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
APInt one(a.getBitWidth(), 1, true);
return quotient.uadd_ov(one, overflowOrDiv0);
});
- // Fold out ceil division by one. Assumes all tensors of all ones are
- // splats.
- if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
- if (rhs.getValue() == 1)
- return getLhs();
- } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
- if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
- return getLhs();
- }
return overflowOrDiv0 ? Attribute() : result;
}
@@ -361,6 +346,10 @@ OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
+ // ceildivsi (x, 1) -> x.
+ if (matchPattern(getRhs(), m_One()))
+ return getLhs();
+
// Don't fold if it would overflow or if it requires a division by zero.
bool overflowOrDiv0 = false;
auto result =
@@ -398,16 +387,6 @@ OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
return zero.ssub_ov(div, overflowOrDiv0);
});
- // Fold out ceil division by one. Assumes all tensors of all ones are
- // splats.
- if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
- if (rhs.getValue() == 1)
- return getLhs();
- } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
- if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
- return getLhs();
- }
-
return overflowOrDiv0 ? Attribute() : result;
}
@@ -416,6 +395,10 @@ OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
+ // floordivsi (x, 1) -> x.
+ if (matchPattern(getRhs(), m_One()))
+ return getLhs();
+
// Don't fold if it would overflow or if it requires a division by zero.
bool overflowOrDiv0 = false;
auto result =
@@ -453,16 +436,6 @@ OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
return zero.ssub_ov(ceil, overflowOrDiv0);
});
- // Fold out floor division by one. Assumes all tensors of all ones are
- // splats.
- if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
- if (rhs.getValue() == 1)
- return getLhs();
- } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
- if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
- return getLhs();
- }
-
return overflowOrDiv0 ? Attribute() : result;
}
More information about the Mlir-commits
mailing list