[Mlir-commits] [mlir] a3fa6b8 - [mlir][Arith][NFC] Migrate Arith dialect to the new fold API
Markus Böck
llvmlistbot at llvm.org
Wed Jan 11 07:17:18 PST 2023
Author: Markus Böck
Date: 2023-01-11T16:16:22+01:00
New Revision: a3fa6b86fa424dab1f1b6cbc38e27c8c3364f2f4
URL: https://github.com/llvm/llvm-project/commit/a3fa6b86fa424dab1f1b6cbc38e27c8c3364f2f4
DIFF: https://github.com/llvm/llvm-project/commit/a3fa6b86fa424dab1f1b6cbc38e27c8c3364f2f4.diff
LOG: [mlir][Arith][NFC] Migrate Arith dialect to the new fold API
This is the in-tree dialect with by far the most `fold` method implementations. This patch simply changes all implementations to use the new signature.
Admittedly, readability does not improve much in this case, simply because most methods rely on `constFoldBinaryOp`. I did not modify that function or its interface as part of this patch, but that might be worth considering in the future.
Differential Revision: https://reviews.llvm.org/D141490
Added:
Modified:
mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index 78fd7bdf012f8..065d8cfebeaf3 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -24,6 +24,7 @@ def Arith_Dialect : Dialect {
let hasConstantMaterializer = 1;
let useDefaultAttributePrinterParser = 1;
+ let useFoldAPI = kEmitFoldAdaptorFolder;
}
// The predicate indicates the type of the comparison to perform:
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index f812b3c61b366..63febd8577369 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -164,9 +164,7 @@ bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
}
-OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
- return getValue();
-}
+OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
int64_t value, unsigned width) {
@@ -217,7 +215,7 @@ bool arith::ConstantIndexOp::classof(Operation *op) {
// AddIOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
// addi(x, 0) -> x
if (matchPattern(getRhs(), m_Zero()))
return getLhs();
@@ -233,7 +231,8 @@ OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
return sub.getLhs();
return constFoldBinaryOp<IntegerAttr>(
- operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
+ adaptor.getOperands(),
+ [](APInt a, const APInt &b) { return std::move(a) + b; });
}
void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -260,7 +259,7 @@ static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
}
LogicalResult
-arith::AddUIExtendedOp::fold(ArrayRef<Attribute> operands,
+arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
Type overflowTy = getOverflow().getType();
// addui_extended(x, 0) -> x, false
@@ -280,21 +279,22 @@ arith::AddUIExtendedOp::fold(ArrayRef<Attribute> operands,
// `constFoldBinaryOp` again to calculate the overflow bit because the
// constructed attribute is of the same element type as both operands.
if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
- operands, [](APInt a, const APInt &b) { return std::move(a) + b; })) {
+ adaptor.getOperands(),
+ [](APInt a, const APInt &b) { return std::move(a) + b; })) {
Attribute overflowAttr;
- if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
+ if (auto lhs = adaptor.getLhs().dyn_cast<IntegerAttr>()) {
// Both arguments are scalars, calculate the scalar overflow value.
auto sum = sumAttr.cast<IntegerAttr>();
overflowAttr = IntegerAttr::get(
overflowTy,
calculateUnsignedOverflow(sum.getValue(), lhs.getValue()));
- } else if (auto lhs = operands[0].dyn_cast<SplatElementsAttr>()) {
+ } else if (auto lhs = adaptor.getLhs().dyn_cast<SplatElementsAttr>()) {
// Both arguments are splats, calculate the splat overflow value.
auto sum = sumAttr.cast<SplatElementsAttr>();
APInt overflow = calculateUnsignedOverflow(sum.getSplatValue<APInt>(),
lhs.getSplatValue<APInt>());
overflowAttr = SplatElementsAttr::get(overflowTy, overflow);
- } else if (auto lhs = operands[0].dyn_cast<ElementsAttr>()) {
+ } else if (auto lhs = adaptor.getLhs().dyn_cast<ElementsAttr>()) {
// Othwerwise calculate element-wise overflow values.
auto sum = sumAttr.cast<ElementsAttr>();
const auto numElems = static_cast<size_t>(sum.getNumElements());
@@ -328,7 +328,7 @@ void arith::AddUIExtendedOp::getCanonicalizationPatterns(
// SubIOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
// subi(x,x) -> 0
if (getOperand(0) == getOperand(1))
return Builder(getContext()).getZeroAttr(getType());
@@ -346,7 +346,8 @@ OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
}
return constFoldBinaryOp<IntegerAttr>(
- operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
+ adaptor.getOperands(),
+ [](APInt a, const APInt &b) { return std::move(a) - b; });
}
void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -360,7 +361,7 @@ void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// MulIOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
// muli(x, 0) -> 0
if (matchPattern(getRhs(), m_Zero()))
return getRhs();
@@ -371,7 +372,8 @@ OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
// default folder
return constFoldBinaryOp<IntegerAttr>(
- operands, [](const APInt &a, const APInt &b) { return a * b; });
+ adaptor.getOperands(),
+ [](const APInt &a, const APInt &b) { return a * b; });
}
//===----------------------------------------------------------------------===//
@@ -386,11 +388,11 @@ arith::MulSIExtendedOp::getShapeForUnroll() {
}
LogicalResult
-arith::MulSIExtendedOp::fold(ArrayRef<Attribute> operands,
+arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
// mulsi_extended(x, 0) -> 0, 0
if (matchPattern(getRhs(), m_Zero())) {
- Attribute zero = operands[1];
+ Attribute zero = adaptor.getRhs();
results.push_back(zero);
results.push_back(zero);
return success();
@@ -398,10 +400,11 @@ arith::MulSIExtendedOp::fold(ArrayRef<Attribute> operands,
// mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
- operands, [](const APInt &a, const APInt &b) { return a * b; })) {
+ adaptor.getOperands(),
+ [](const APInt &a, const APInt &b) { return a * b; })) {
// Invoke the constant fold helper again to calculate the 'high' result.
Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
- operands, [](const APInt &a, const APInt &b) {
+ adaptor.getOperands(), [](const APInt &a, const APInt &b) {
unsigned bitWidth = a.getBitWidth();
APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
return fullProduct.extractBits(bitWidth, bitWidth);
@@ -433,11 +436,11 @@ arith::MulUIExtendedOp::getShapeForUnroll() {
}
LogicalResult
-arith::MulUIExtendedOp::fold(ArrayRef<Attribute> operands,
+arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
// mului_extended(x, 0) -> 0, 0
if (matchPattern(getRhs(), m_Zero())) {
- Attribute zero = operands[1];
+ Attribute zero = adaptor.getRhs();
results.push_back(zero);
results.push_back(zero);
return success();
@@ -454,10 +457,11 @@ arith::MulUIExtendedOp::fold(ArrayRef<Attribute> operands,
// mului_extended(cst_a, cst_b) -> cst_low, cst_high
if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
- operands, [](const APInt &a, const APInt &b) { return a * b; })) {
+ adaptor.getOperands(),
+ [](const APInt &a, const APInt &b) { return a * b; })) {
// Invoke the constant fold helper again to calculate the 'high' result.
Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
- operands, [](const APInt &a, const APInt &b) {
+ adaptor.getOperands(), [](const APInt &a, const APInt &b) {
unsigned bitWidth = a.getBitWidth();
APInt fullProduct = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
return fullProduct.extractBits(bitWidth, bitWidth);
@@ -481,21 +485,21 @@ void arith::MulUIExtendedOp::getCanonicalizationPatterns(
// DivUIOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
// divui (x, 1) -> x.
if (matchPattern(getRhs(), m_One()))
return getLhs();
// Don't fold if it would require a division by zero.
bool div0 = false;
- auto result =
- constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
- if (div0 || !b) {
- div0 = true;
- return a;
- }
- return a.udiv(b);
- });
+ auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
+ [&](APInt a, const APInt &b) {
+ if (div0 || !b) {
+ div0 = true;
+ return a;
+ }
+ return a.udiv(b);
+ });
return div0 ? Attribute() : result;
}
@@ -510,15 +514,15 @@ Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
// DivSIOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
// divsi (x, 1) -> x.
if (matchPattern(getRhs(), m_One()))
return getLhs();
// Don't fold if it would overflow or if it requires a division by zero.
bool overflowOrDiv0 = false;
- auto result =
- constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
+ auto result = constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [&](APInt a, const APInt &b) {
if (overflowOrDiv0 || !b) {
overflowOrDiv0 = true;
return a;
@@ -557,14 +561,14 @@ static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
// CeilDivUIOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
// ceildivui (x, 1) -> x.
if (matchPattern(getRhs(), m_One()))
return getLhs();
bool overflowOrDiv0 = false;
- auto result =
- constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
+ auto result = constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [&](APInt a, const APInt &b) {
if (overflowOrDiv0 || !b) {
overflowOrDiv0 = true;
return a;
@@ -589,15 +593,15 @@ Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
// CeilDivSIOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
// ceildivsi (x, 1) -> x.
if (matchPattern(getRhs(), m_One()))
return getLhs();
// Don't fold if it would overflow or if it requires a division by zero.
bool overflowOrDiv0 = false;
- auto result =
- constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
+ auto result = constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [&](APInt a, const APInt &b) {
if (overflowOrDiv0 || !b) {
overflowOrDiv0 = true;
return a;
@@ -650,15 +654,15 @@ Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
// FloorDivSIOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
// floordivsi (x, 1) -> x.
if (matchPattern(getRhs(), m_One()))
return getLhs();
// Don't fold if it would overflow or if it requires a division by zero.
bool overflowOrDiv0 = false;
- auto result =
- constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
+ auto result = constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [&](APInt a, const APInt &b) {
if (overflowOrDiv0 || !b) {
overflowOrDiv0 = true;
return a;
@@ -699,21 +703,21 @@ OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
// RemUIOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
// remui (x, 1) -> 0.
if (matchPattern(getRhs(), m_One()))
return Builder(getContext()).getZeroAttr(getType());
// Don't fold if it would require a division by zero.
bool div0 = false;
- auto result =
- constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
- if (div0 || b.isNullValue()) {
- div0 = true;
- return a;
- }
- return a.urem(b);
- });
+ auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
+ [&](APInt a, const APInt &b) {
+ if (div0 || b.isNullValue()) {
+ div0 = true;
+ return a;
+ }
+ return a.urem(b);
+ });
return div0 ? Attribute() : result;
}
@@ -722,21 +726,21 @@ OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
// RemSIOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
// remsi (x, 1) -> 0.
if (matchPattern(getRhs(), m_One()))
return Builder(getContext()).getZeroAttr(getType());
// Don't fold if it would require a division by zero.
bool div0 = false;
- auto result =
- constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
- if (div0 || b.isNullValue()) {
- div0 = true;
- return a;
- }
- return a.srem(b);
- });
+ auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
+ [&](APInt a, const APInt &b) {
+ if (div0 || b.isNullValue()) {
+ div0 = true;
+ return a;
+ }
+ return a.srem(b);
+ });
return div0 ? Attribute() : result;
}
@@ -762,7 +766,7 @@ static Value foldAndIofAndI(arith::AndIOp op) {
return {};
}
-OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
/// and(x, 0) -> 0
if (matchPattern(getRhs(), m_Zero()))
return getRhs();
@@ -786,31 +790,33 @@ OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
return result;
return constFoldBinaryOp<IntegerAttr>(
- operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
+ adaptor.getOperands(),
+ [](APInt a, const APInt &b) { return std::move(a) & b; });
}
//===----------------------------------------------------------------------===//
// OrIOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
/// or(x, 0) -> x
if (matchPattern(getRhs(), m_Zero()))
return getLhs();
/// or(x, <all ones>) -> <all ones>
- if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
+ if (auto rhsAttr = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>())
if (rhsAttr.getValue().isAllOnes())
return rhsAttr;
return constFoldBinaryOp<IntegerAttr>(
- operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
+ adaptor.getOperands(),
+ [](APInt a, const APInt &b) { return std::move(a) | b; });
}
//===----------------------------------------------------------------------===//
// XOrIOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
/// xor(x, 0) -> x
if (matchPattern(getRhs(), m_Zero()))
return getLhs();
@@ -835,7 +841,8 @@ OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
}
return constFoldBinaryOp<IntegerAttr>(
- operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
+ adaptor.getOperands(),
+ [](APInt a, const APInt &b) { return std::move(a) ^ b; });
}
void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -847,11 +854,11 @@ void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// NegFOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
/// negf(negf(x)) -> x
if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
return op.getOperand();
- return constFoldUnaryOp<FloatAttr>(operands,
+ return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
[](const APFloat &a) { return -a; });
}
@@ -859,35 +866,35 @@ OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) {
// AddFOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
// addf(x, -0) -> x
if (matchPattern(getRhs(), m_NegZeroFloat()))
return getLhs();
return constFoldBinaryOp<FloatAttr>(
- operands, [](const APFloat &a, const APFloat &b) { return a + b; });
+ adaptor.getOperands(),
+ [](const APFloat &a, const APFloat &b) { return a + b; });
}
//===----------------------------------------------------------------------===//
// SubFOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
// subf(x, +0) -> x
if (matchPattern(getRhs(), m_PosZeroFloat()))
return getLhs();
return constFoldBinaryOp<FloatAttr>(
- operands, [](const APFloat &a, const APFloat &b) { return a - b; });
+ adaptor.getOperands(),
+ [](const APFloat &a, const APFloat &b) { return a - b; });
}
//===----------------------------------------------------------------------===//
// MaxFOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "maxf takes two operands");
-
+OpFoldResult arith::MaxFOp::fold(FoldAdaptor adaptor) {
// maxf(x,x) -> x
if (getLhs() == getRhs())
return getRhs();
@@ -897,7 +904,7 @@ OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
return getLhs();
return constFoldBinaryOp<FloatAttr>(
- operands,
+ adaptor.getOperands(),
[](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
}
@@ -905,9 +912,7 @@ OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
// MaxSIOp
//===----------------------------------------------------------------------===//
-OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "binary operation takes two operands");
-
+OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
// maxsi(x,x) -> x
if (getLhs() == getRhs())
return getRhs();
@@ -923,7 +928,7 @@ OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
intValue.isMinSignedValue())
return getLhs();
- return constFoldBinaryOp<IntegerAttr>(operands,
+ return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
[](const APInt &a, const APInt &b) {
return llvm::APIntOps::smax(a, b);
});
@@ -933,9 +938,7 @@ OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
// MaxUIOp
//===----------------------------------------------------------------------===//
-OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "binary operation takes two operands");
-
+OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
// maxui(x,x) -> x
if (getLhs() == getRhs())
return getRhs();
@@ -949,7 +952,7 @@ OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
return getLhs();
- return constFoldBinaryOp<IntegerAttr>(operands,
+ return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
[](const APInt &a, const APInt &b) {
return llvm::APIntOps::umax(a, b);
});
@@ -959,9 +962,7 @@ OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
// MinFOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "minf takes two operands");
-
+OpFoldResult arith::MinFOp::fold(FoldAdaptor adaptor) {
// minf(x,x) -> x
if (getLhs() == getRhs())
return getRhs();
@@ -971,7 +972,7 @@ OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
return getLhs();
return constFoldBinaryOp<FloatAttr>(
- operands,
+ adaptor.getOperands(),
[](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
}
@@ -979,9 +980,7 @@ OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
// MinSIOp
//===----------------------------------------------------------------------===//
-OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "binary operation takes two operands");
-
+OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
// minsi(x,x) -> x
if (getLhs() == getRhs())
return getRhs();
@@ -997,7 +996,7 @@ OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
intValue.isMaxSignedValue())
return getLhs();
- return constFoldBinaryOp<IntegerAttr>(operands,
+ return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
[](const APInt &a, const APInt &b) {
return llvm::APIntOps::smin(a, b);
});
@@ -1007,9 +1006,7 @@ OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
// MinUIOp
//===----------------------------------------------------------------------===//
-OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "binary operation takes two operands");
-
+OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
// minui(x,x) -> x
if (getLhs() == getRhs())
return getRhs();
@@ -1023,7 +1020,7 @@ OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
return getLhs();
- return constFoldBinaryOp<IntegerAttr>(operands,
+ return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
[](const APInt &a, const APInt &b) {
return llvm::APIntOps::umin(a, b);
});
@@ -1033,13 +1030,14 @@ OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
// MulFOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
// mulf(x, 1) -> x
if (matchPattern(getRhs(), m_OneFloat()))
return getLhs();
return constFoldBinaryOp<FloatAttr>(
- operands, [](const APFloat &a, const APFloat &b) { return a * b; });
+ adaptor.getOperands(),
+ [](const APFloat &a, const APFloat &b) { return a * b; });
}
void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -1051,13 +1049,14 @@ void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// DivFOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
// divf(x, 1) -> x
if (matchPattern(getRhs(), m_OneFloat()))
return getLhs();
return constFoldBinaryOp<FloatAttr>(
- operands, [](const APFloat &a, const APFloat &b) { return a / b; });
+ adaptor.getOperands(),
+ [](const APFloat &a, const APFloat &b) { return a / b; });
}
void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -1069,8 +1068,8 @@ void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// RemFOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::RemFOp::fold(ArrayRef<Attribute> operands) {
- return constFoldBinaryOp<FloatAttr>(operands,
+OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
+ return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
[](const APFloat &a, const APFloat &b) {
APFloat result(a);
(void)result.remainder(b);
@@ -1170,7 +1169,7 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
// ExtUIOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
getInMutable().assign(lhs.getIn());
return getResult();
@@ -1179,7 +1178,8 @@ OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
Type resType = getElementTypeOrSelf(getType());
unsigned bitWidth = resType.cast<IntegerType>().getWidth();
return constFoldCastOp<IntegerAttr, IntegerAttr>(
- operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
+ adaptor.getOperands(), getType(),
+ [bitWidth](const APInt &a, bool &castStatus) {
return a.zext(bitWidth);
});
}
@@ -1196,7 +1196,7 @@ LogicalResult arith::ExtUIOp::verify() {
// ExtSIOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
getInMutable().assign(lhs.getIn());
return getResult();
@@ -1205,7 +1205,8 @@ OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
Type resType = getElementTypeOrSelf(getType());
unsigned bitWidth = resType.cast<IntegerType>().getWidth();
return constFoldCastOp<IntegerAttr, IntegerAttr>(
- operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
+ adaptor.getOperands(), getType(),
+ [bitWidth](const APInt &a, bool &castStatus) {
return a.sext(bitWidth);
});
}
@@ -1237,9 +1238,7 @@ LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
// TruncIOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 1 && "unary operation takes one operand");
-
+OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
// trunci(zexti(a)) -> a
// trunci(sexti(a)) -> a
if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
@@ -1255,7 +1254,8 @@ OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
Type resType = getElementTypeOrSelf(getType());
unsigned bitWidth = resType.cast<IntegerType>().getWidth();
return constFoldCastOp<IntegerAttr, IntegerAttr>(
- operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
+ adaptor.getOperands(), getType(),
+ [bitWidth](const APInt &a, bool &castStatus) {
return a.trunc(bitWidth);
});
}
@@ -1280,10 +1280,8 @@ LogicalResult arith::TruncIOp::verify() {
/// Perform safe const propagation for truncf, i.e. only propagate if FP value
/// can be represented without precision loss or rounding.
-OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 1 && "unary operation takes one operand");
-
- auto constOperand = operands.front();
+OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
+ auto constOperand = adaptor.getIn();
if (!constOperand || !constOperand.isa<FloatAttr>())
return {};
@@ -1348,10 +1346,11 @@ bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
}
-OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
Type resEleType = getElementTypeOrSelf(getType());
return constFoldCastOp<IntegerAttr, FloatAttr>(
- operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
+ adaptor.getOperands(), getType(),
+ [&resEleType](const APInt &a, bool &castStatus) {
FloatType floatTy = resEleType.cast<FloatType>();
APFloat apf(floatTy.getFloatSemantics(),
APInt::getZero(floatTy.getWidth()));
@@ -1369,10 +1368,11 @@ bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
}
-OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
Type resEleType = getElementTypeOrSelf(getType());
return constFoldCastOp<IntegerAttr, FloatAttr>(
- operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
+ adaptor.getOperands(), getType(),
+ [&resEleType](const APInt &a, bool &castStatus) {
FloatType floatTy = resEleType.cast<FloatType>();
APFloat apf(floatTy.getFloatSemantics(),
APInt::getZero(floatTy.getWidth()));
@@ -1389,11 +1389,12 @@ bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
}
-OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
Type resType = getElementTypeOrSelf(getType());
unsigned bitWidth = resType.cast<IntegerType>().getWidth();
return constFoldCastOp<FloatAttr, IntegerAttr>(
- operands, getType(), [&bitWidth](const APFloat &a, bool &castStatus) {
+ adaptor.getOperands(), getType(),
+ [&bitWidth](const APFloat &a, bool &castStatus) {
bool ignored;
APSInt api(bitWidth, /*isUnsigned=*/true);
castStatus = APFloat::opInvalidOp !=
@@ -1410,11 +1411,12 @@ bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
}
-OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
Type resType = getElementTypeOrSelf(getType());
unsigned bitWidth = resType.cast<IntegerType>().getWidth();
return constFoldCastOp<FloatAttr, IntegerAttr>(
- operands, getType(), [&bitWidth](const APFloat &a, bool &castStatus) {
+ adaptor.getOperands(), getType(),
+ [&bitWidth](const APFloat &a, bool &castStatus) {
bool ignored;
APSInt api(bitWidth, /*isUnsigned=*/false);
castStatus = APFloat::opInvalidOp !=
@@ -1445,11 +1447,11 @@ bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
return areIndexCastCompatible(inputs, outputs);
}
-OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
// index_cast(constant) -> constant
// A little hack because we go through int. Otherwise, the size of the
// constant might need to change.
- if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
+ if (auto value = adaptor.getIn().dyn_cast_or_null<IntegerAttr>())
return IntegerAttr::get(getType(), value.getInt());
return {};
@@ -1469,11 +1471,11 @@ bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
return areIndexCastCompatible(inputs, outputs);
}
-OpFoldResult arith::IndexCastUIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
// index_castui(constant) -> constant
// A little hack because we go through int. Otherwise, the size of the
// constant might need to change.
- if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
+ if (auto value = adaptor.getIn().dyn_cast_or_null<IntegerAttr>())
return IntegerAttr::get(getType(), value.getValue().getZExtValue());
return {};
@@ -1502,11 +1504,9 @@ bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
}
-OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 1 && "bitcast op expects 1 operand");
-
+OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
auto resType = getType();
- auto operand = operands[0];
+ auto operand = adaptor.getIn();
if (!operand)
return {};
@@ -1620,9 +1620,7 @@ static std::optional<int64_t> getIntegerWidth(Type t) {
return std::nullopt;
}
-OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "cmpi takes two operands");
-
+OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
// cmpi(pred, x, x)
if (getLhs() == getRhs()) {
auto val = applyCmpPredicateToEqualOperands(getPredicate());
@@ -1649,7 +1647,7 @@ OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
}
// Move constant to the right side.
- if (operands[0] && !operands[1]) {
+ if (adaptor.getLhs() && !adaptor.getRhs()) {
// Do not use invertPredicate, as it will change eq to ne and vice versa.
using Pred = CmpIPredicate;
const std::pair<Pred, Pred> invPreds[] = {
@@ -1672,13 +1670,13 @@ OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
llvm_unreachable("unknown cmpi predicate kind");
}
- auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
+ auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
if (!lhs)
return {};
// We are moving constants to the right side; So if lhs is constant rhs is
// guaranteed to be a constant.
- auto rhs = operands.back().cast<IntegerAttr>();
+ auto rhs = adaptor.getRhs().cast<IntegerAttr>();
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
return BoolAttr::get(getContext(), val);
@@ -1741,11 +1739,9 @@ bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
llvm_unreachable("unknown cmpf predicate kind");
}
-OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "cmpf takes two operands");
-
- auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
- auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
+OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
+ auto lhs = adaptor.getLhs().dyn_cast_or_null<FloatAttr>();
+ auto rhs = adaptor.getRhs().dyn_cast_or_null<FloatAttr>();
// If one operand is NaN, making them both NaN does not change the result.
if (lhs && lhs.getValue().isNaN())
@@ -2123,7 +2119,7 @@ void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<SelectI1Simplify, SelectToExtUI>(context);
}
-OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
Value trueVal = getTrueValue();
Value falseVal = getFalseValue();
if (trueVal == falseVal)
@@ -2220,14 +2216,14 @@ LogicalResult arith::SelectOp::verify() {
// ShLIOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
// shli(x, 0) -> x
if (matchPattern(getRhs(), m_Zero()))
return getLhs();
// Don't fold if shifting more than the bit width.
bool bounded = false;
auto result = constFoldBinaryOp<IntegerAttr>(
- operands, [&](const APInt &a, const APInt &b) {
+ adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
bounded = b.ule(b.getBitWidth());
return a.shl(b);
});
@@ -2238,14 +2234,14 @@ OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
// ShRUIOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
// shrui(x, 0) -> x
if (matchPattern(getRhs(), m_Zero()))
return getLhs();
// Don't fold if shifting more than the bit width.
bool bounded = false;
auto result = constFoldBinaryOp<IntegerAttr>(
- operands, [&](const APInt &a, const APInt &b) {
+ adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
bounded = b.ule(b.getBitWidth());
return a.lshr(b);
});
@@ -2256,14 +2252,14 @@ OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
// ShRSIOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
// shrsi(x, 0) -> x
if (matchPattern(getRhs(), m_Zero()))
return getLhs();
// Don't fold if shifting more than the bit width.
bool bounded = false;
auto result = constFoldBinaryOp<IntegerAttr>(
- operands, [&](const APInt &a, const APInt &b) {
+ adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
bounded = b.ule(b.getBitWidth());
return a.ashr(b);
});
More information about the Mlir-commits
mailing list