[Mlir-commits] [mlir] a3fa6b8 - [mlir][Arith][NFC] Migrate Arith dialect to the new fold API

Markus Böck llvmlistbot at llvm.org
Wed Jan 11 07:17:18 PST 2023


Author: Markus Böck
Date: 2023-01-11T16:16:22+01:00
New Revision: a3fa6b86fa424dab1f1b6cbc38e27c8c3364f2f4

URL: https://github.com/llvm/llvm-project/commit/a3fa6b86fa424dab1f1b6cbc38e27c8c3364f2f4
DIFF: https://github.com/llvm/llvm-project/commit/a3fa6b86fa424dab1f1b6cbc38e27c8c3364f2f4.diff

LOG: [mlir][Arith][NFC] Migrate Arith dialect to the new fold API

This is by far the in-tree dialect with the most `fold` method implementations. This patch simply changes all of them to use the new `FoldAdaptor` signature.

Admittedly, code readability does not improve much in this case, simply because most methods use `constFoldBinaryOp`. I did not modify that function or its interface as part of this patch, but that might be something to consider in the future.

Differential Revision: https://reviews.llvm.org/D141490

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index 78fd7bdf012f8..065d8cfebeaf3 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -24,6 +24,7 @@ def Arith_Dialect : Dialect {
 
   let hasConstantMaterializer = 1;
   let useDefaultAttributePrinterParser = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }
 
 // The predicate indicates the type of the comparison to perform:

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index f812b3c61b366..63febd8577369 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -164,9 +164,7 @@ bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
 }
 
-OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
-  return getValue();
-}
+OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
 
 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                                  int64_t value, unsigned width) {
@@ -217,7 +215,7 @@ bool arith::ConstantIndexOp::classof(Operation *op) {
 // AddIOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
   // addi(x, 0) -> x
   if (matchPattern(getRhs(), m_Zero()))
     return getLhs();
@@ -233,7 +231,8 @@ OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
       return sub.getLhs();
 
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
+      adaptor.getOperands(),
+      [](APInt a, const APInt &b) { return std::move(a) + b; });
 }
 
 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -260,7 +259,7 @@ static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
 }
 
 LogicalResult
-arith::AddUIExtendedOp::fold(ArrayRef<Attribute> operands,
+arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
   Type overflowTy = getOverflow().getType();
   // addui_extended(x, 0) -> x, false
@@ -280,21 +279,22 @@ arith::AddUIExtendedOp::fold(ArrayRef<Attribute> operands,
   // `constFoldBinaryOp` again to calculate the overflow bit because the
   // constructed attribute is of the same element type as both operands.
   if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
-          operands, [](APInt a, const APInt &b) { return std::move(a) + b; })) {
+          adaptor.getOperands(),
+          [](APInt a, const APInt &b) { return std::move(a) + b; })) {
     Attribute overflowAttr;
-    if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
+    if (auto lhs = adaptor.getLhs().dyn_cast<IntegerAttr>()) {
       // Both arguments are scalars, calculate the scalar overflow value.
       auto sum = sumAttr.cast<IntegerAttr>();
       overflowAttr = IntegerAttr::get(
           overflowTy,
           calculateUnsignedOverflow(sum.getValue(), lhs.getValue()));
-    } else if (auto lhs = operands[0].dyn_cast<SplatElementsAttr>()) {
+    } else if (auto lhs = adaptor.getLhs().dyn_cast<SplatElementsAttr>()) {
       // Both arguments are splats, calculate the splat overflow value.
       auto sum = sumAttr.cast<SplatElementsAttr>();
       APInt overflow = calculateUnsignedOverflow(sum.getSplatValue<APInt>(),
                                                  lhs.getSplatValue<APInt>());
       overflowAttr = SplatElementsAttr::get(overflowTy, overflow);
-    } else if (auto lhs = operands[0].dyn_cast<ElementsAttr>()) {
+    } else if (auto lhs = adaptor.getLhs().dyn_cast<ElementsAttr>()) {
       // Othwerwise calculate element-wise overflow values.
       auto sum = sumAttr.cast<ElementsAttr>();
       const auto numElems = static_cast<size_t>(sum.getNumElements());
@@ -328,7 +328,7 @@ void arith::AddUIExtendedOp::getCanonicalizationPatterns(
 // SubIOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
   // subi(x,x) -> 0
   if (getOperand(0) == getOperand(1))
     return Builder(getContext()).getZeroAttr(getType());
@@ -346,7 +346,8 @@ OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
   }
 
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
+      adaptor.getOperands(),
+      [](APInt a, const APInt &b) { return std::move(a) - b; });
 }
 
 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -360,7 +361,7 @@ void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 // MulIOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
   // muli(x, 0) -> 0
   if (matchPattern(getRhs(), m_Zero()))
     return getRhs();
@@ -371,7 +372,8 @@ OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
 
   // default folder
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [](const APInt &a, const APInt &b) { return a * b; });
+      adaptor.getOperands(),
+      [](const APInt &a, const APInt &b) { return a * b; });
 }
 
 //===----------------------------------------------------------------------===//
@@ -386,11 +388,11 @@ arith::MulSIExtendedOp::getShapeForUnroll() {
 }
 
 LogicalResult
-arith::MulSIExtendedOp::fold(ArrayRef<Attribute> operands,
+arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
   // mulsi_extended(x, 0) -> 0, 0
   if (matchPattern(getRhs(), m_Zero())) {
-    Attribute zero = operands[1];
+    Attribute zero = adaptor.getRhs();
     results.push_back(zero);
     results.push_back(zero);
     return success();
@@ -398,10 +400,11 @@ arith::MulSIExtendedOp::fold(ArrayRef<Attribute> operands,
 
   // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
   if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
-          operands, [](const APInt &a, const APInt &b) { return a * b; })) {
+          adaptor.getOperands(),
+          [](const APInt &a, const APInt &b) { return a * b; })) {
     // Invoke the constant fold helper again to calculate the 'high' result.
     Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
-        operands, [](const APInt &a, const APInt &b) {
+        adaptor.getOperands(), [](const APInt &a, const APInt &b) {
           unsigned bitWidth = a.getBitWidth();
           APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
           return fullProduct.extractBits(bitWidth, bitWidth);
@@ -433,11 +436,11 @@ arith::MulUIExtendedOp::getShapeForUnroll() {
 }
 
 LogicalResult
-arith::MulUIExtendedOp::fold(ArrayRef<Attribute> operands,
+arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
   // mului_extended(x, 0) -> 0, 0
   if (matchPattern(getRhs(), m_Zero())) {
-    Attribute zero = operands[1];
+    Attribute zero = adaptor.getRhs();
     results.push_back(zero);
     results.push_back(zero);
     return success();
@@ -454,10 +457,11 @@ arith::MulUIExtendedOp::fold(ArrayRef<Attribute> operands,
 
   // mului_extended(cst_a, cst_b) -> cst_low, cst_high
   if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
-          operands, [](const APInt &a, const APInt &b) { return a * b; })) {
+          adaptor.getOperands(),
+          [](const APInt &a, const APInt &b) { return a * b; })) {
     // Invoke the constant fold helper again to calculate the 'high' result.
     Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
-        operands, [](const APInt &a, const APInt &b) {
+        adaptor.getOperands(), [](const APInt &a, const APInt &b) {
           unsigned bitWidth = a.getBitWidth();
           APInt fullProduct = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
           return fullProduct.extractBits(bitWidth, bitWidth);
@@ -481,21 +485,21 @@ void arith::MulUIExtendedOp::getCanonicalizationPatterns(
 // DivUIOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
   // divui (x, 1) -> x.
   if (matchPattern(getRhs(), m_One()))
     return getLhs();
 
   // Don't fold if it would require a division by zero.
   bool div0 = false;
-  auto result =
-      constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
-        if (div0 || !b) {
-          div0 = true;
-          return a;
-        }
-        return a.udiv(b);
-      });
+  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
+                                               [&](APInt a, const APInt &b) {
+                                                 if (div0 || !b) {
+                                                   div0 = true;
+                                                   return a;
+                                                 }
+                                                 return a.udiv(b);
+                                               });
 
   return div0 ? Attribute() : result;
 }
@@ -510,15 +514,15 @@ Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
 // DivSIOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
   // divsi (x, 1) -> x.
   if (matchPattern(getRhs(), m_One()))
     return getLhs();
 
   // Don't fold if it would overflow or if it requires a division by zero.
   bool overflowOrDiv0 = false;
-  auto result =
-      constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
+  auto result = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](APInt a, const APInt &b) {
         if (overflowOrDiv0 || !b) {
           overflowOrDiv0 = true;
           return a;
@@ -557,14 +561,14 @@ static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
 // CeilDivUIOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
   // ceildivui (x, 1) -> x.
   if (matchPattern(getRhs(), m_One()))
     return getLhs();
 
   bool overflowOrDiv0 = false;
-  auto result =
-      constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
+  auto result = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](APInt a, const APInt &b) {
         if (overflowOrDiv0 || !b) {
           overflowOrDiv0 = true;
           return a;
@@ -589,15 +593,15 @@ Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
 // CeilDivSIOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
   // ceildivsi (x, 1) -> x.
   if (matchPattern(getRhs(), m_One()))
     return getLhs();
 
   // Don't fold if it would overflow or if it requires a division by zero.
   bool overflowOrDiv0 = false;
-  auto result =
-      constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
+  auto result = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](APInt a, const APInt &b) {
         if (overflowOrDiv0 || !b) {
           overflowOrDiv0 = true;
           return a;
@@ -650,15 +654,15 @@ Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
 // FloorDivSIOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
   // floordivsi (x, 1) -> x.
   if (matchPattern(getRhs(), m_One()))
     return getLhs();
 
   // Don't fold if it would overflow or if it requires a division by zero.
   bool overflowOrDiv0 = false;
-  auto result =
-      constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
+  auto result = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](APInt a, const APInt &b) {
         if (overflowOrDiv0 || !b) {
           overflowOrDiv0 = true;
           return a;
@@ -699,21 +703,21 @@ OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
 // RemUIOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
   // remui (x, 1) -> 0.
   if (matchPattern(getRhs(), m_One()))
     return Builder(getContext()).getZeroAttr(getType());
 
   // Don't fold if it would require a division by zero.
   bool div0 = false;
-  auto result =
-      constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
-        if (div0 || b.isNullValue()) {
-          div0 = true;
-          return a;
-        }
-        return a.urem(b);
-      });
+  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
+                                               [&](APInt a, const APInt &b) {
+                                                 if (div0 || b.isNullValue()) {
+                                                   div0 = true;
+                                                   return a;
+                                                 }
+                                                 return a.urem(b);
+                                               });
 
   return div0 ? Attribute() : result;
 }
@@ -722,21 +726,21 @@ OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
 // RemSIOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
   // remsi (x, 1) -> 0.
   if (matchPattern(getRhs(), m_One()))
     return Builder(getContext()).getZeroAttr(getType());
 
   // Don't fold if it would require a division by zero.
   bool div0 = false;
-  auto result =
-      constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
-        if (div0 || b.isNullValue()) {
-          div0 = true;
-          return a;
-        }
-        return a.srem(b);
-      });
+  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
+                                               [&](APInt a, const APInt &b) {
+                                                 if (div0 || b.isNullValue()) {
+                                                   div0 = true;
+                                                   return a;
+                                                 }
+                                                 return a.srem(b);
+                                               });
 
   return div0 ? Attribute() : result;
 }
@@ -762,7 +766,7 @@ static Value foldAndIofAndI(arith::AndIOp op) {
   return {};
 }
 
-OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
   /// and(x, 0) -> 0
   if (matchPattern(getRhs(), m_Zero()))
     return getRhs();
@@ -786,31 +790,33 @@ OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
     return result;
 
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
+      adaptor.getOperands(),
+      [](APInt a, const APInt &b) { return std::move(a) & b; });
 }
 
 //===----------------------------------------------------------------------===//
 // OrIOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
   /// or(x, 0) -> x
   if (matchPattern(getRhs(), m_Zero()))
     return getLhs();
   /// or(x, <all ones>) -> <all ones>
-  if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
+  if (auto rhsAttr = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>())
     if (rhsAttr.getValue().isAllOnes())
       return rhsAttr;
 
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
+      adaptor.getOperands(),
+      [](APInt a, const APInt &b) { return std::move(a) | b; });
 }
 
 //===----------------------------------------------------------------------===//
 // XOrIOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
   /// xor(x, 0) -> x
   if (matchPattern(getRhs(), m_Zero()))
     return getLhs();
@@ -835,7 +841,8 @@ OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
   }
 
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
+      adaptor.getOperands(),
+      [](APInt a, const APInt &b) { return std::move(a) ^ b; });
 }
 
 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -847,11 +854,11 @@ void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 // NegFOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
   /// negf(negf(x)) -> x
   if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
     return op.getOperand();
-  return constFoldUnaryOp<FloatAttr>(operands,
+  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
                                      [](const APFloat &a) { return -a; });
 }
 
@@ -859,35 +866,35 @@ OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) {
 // AddFOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
   // addf(x, -0) -> x
   if (matchPattern(getRhs(), m_NegZeroFloat()))
     return getLhs();
 
   return constFoldBinaryOp<FloatAttr>(
-      operands, [](const APFloat &a, const APFloat &b) { return a + b; });
+      adaptor.getOperands(),
+      [](const APFloat &a, const APFloat &b) { return a + b; });
 }
 
 //===----------------------------------------------------------------------===//
 // SubFOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
   // subf(x, +0) -> x
   if (matchPattern(getRhs(), m_PosZeroFloat()))
     return getLhs();
 
   return constFoldBinaryOp<FloatAttr>(
-      operands, [](const APFloat &a, const APFloat &b) { return a - b; });
+      adaptor.getOperands(),
+      [](const APFloat &a, const APFloat &b) { return a - b; });
 }
 
 //===----------------------------------------------------------------------===//
 // MaxFOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 2 && "maxf takes two operands");
-
+OpFoldResult arith::MaxFOp::fold(FoldAdaptor adaptor) {
   // maxf(x,x) -> x
   if (getLhs() == getRhs())
     return getRhs();
@@ -897,7 +904,7 @@ OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
     return getLhs();
 
   return constFoldBinaryOp<FloatAttr>(
-      operands,
+      adaptor.getOperands(),
       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
 }
 
@@ -905,9 +912,7 @@ OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
 // MaxSIOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 2 && "binary operation takes two operands");
-
+OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
   // maxsi(x,x) -> x
   if (getLhs() == getRhs())
     return getRhs();
@@ -923,7 +928,7 @@ OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
       intValue.isMinSignedValue())
     return getLhs();
 
-  return constFoldBinaryOp<IntegerAttr>(operands,
+  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                         [](const APInt &a, const APInt &b) {
                                           return llvm::APIntOps::smax(a, b);
                                         });
@@ -933,9 +938,7 @@ OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
 // MaxUIOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 2 && "binary operation takes two operands");
-
+OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
   // maxui(x,x) -> x
   if (getLhs() == getRhs())
     return getRhs();
@@ -949,7 +952,7 @@ OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
     return getLhs();
 
-  return constFoldBinaryOp<IntegerAttr>(operands,
+  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                         [](const APInt &a, const APInt &b) {
                                           return llvm::APIntOps::umax(a, b);
                                         });
@@ -959,9 +962,7 @@ OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
 // MinFOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 2 && "minf takes two operands");
-
+OpFoldResult arith::MinFOp::fold(FoldAdaptor adaptor) {
   // minf(x,x) -> x
   if (getLhs() == getRhs())
     return getRhs();
@@ -971,7 +972,7 @@ OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
     return getLhs();
 
   return constFoldBinaryOp<FloatAttr>(
-      operands,
+      adaptor.getOperands(),
       [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
 }
 
@@ -979,9 +980,7 @@ OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
 // MinSIOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 2 && "binary operation takes two operands");
-
+OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
   // minsi(x,x) -> x
   if (getLhs() == getRhs())
     return getRhs();
@@ -997,7 +996,7 @@ OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
       intValue.isMaxSignedValue())
     return getLhs();
 
-  return constFoldBinaryOp<IntegerAttr>(operands,
+  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                         [](const APInt &a, const APInt &b) {
                                           return llvm::APIntOps::smin(a, b);
                                         });
@@ -1007,9 +1006,7 @@ OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
 // MinUIOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 2 && "binary operation takes two operands");
-
+OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
   // minui(x,x) -> x
   if (getLhs() == getRhs())
     return getRhs();
@@ -1023,7 +1020,7 @@ OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
     return getLhs();
 
-  return constFoldBinaryOp<IntegerAttr>(operands,
+  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                         [](const APInt &a, const APInt &b) {
                                           return llvm::APIntOps::umin(a, b);
                                         });
@@ -1033,13 +1030,14 @@ OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
 // MulFOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
   // mulf(x, 1) -> x
   if (matchPattern(getRhs(), m_OneFloat()))
     return getLhs();
 
   return constFoldBinaryOp<FloatAttr>(
-      operands, [](const APFloat &a, const APFloat &b) { return a * b; });
+      adaptor.getOperands(),
+      [](const APFloat &a, const APFloat &b) { return a * b; });
 }
 
 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -1051,13 +1049,14 @@ void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 // DivFOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
   // divf(x, 1) -> x
   if (matchPattern(getRhs(), m_OneFloat()))
     return getLhs();
 
   return constFoldBinaryOp<FloatAttr>(
-      operands, [](const APFloat &a, const APFloat &b) { return a / b; });
+      adaptor.getOperands(),
+      [](const APFloat &a, const APFloat &b) { return a / b; });
 }
 
 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -1069,8 +1068,8 @@ void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 // RemFOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::RemFOp::fold(ArrayRef<Attribute> operands) {
-  return constFoldBinaryOp<FloatAttr>(operands,
+OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
+  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
                                       [](const APFloat &a, const APFloat &b) {
                                         APFloat result(a);
                                         (void)result.remainder(b);
@@ -1170,7 +1169,7 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
 // ExtUIOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
     getInMutable().assign(lhs.getIn());
     return getResult();
@@ -1179,7 +1178,8 @@ OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
   Type resType = getElementTypeOrSelf(getType());
   unsigned bitWidth = resType.cast<IntegerType>().getWidth();
   return constFoldCastOp<IntegerAttr, IntegerAttr>(
-      operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
+      adaptor.getOperands(), getType(),
+      [bitWidth](const APInt &a, bool &castStatus) {
         return a.zext(bitWidth);
       });
 }
@@ -1196,7 +1196,7 @@ LogicalResult arith::ExtUIOp::verify() {
 // ExtSIOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
     getInMutable().assign(lhs.getIn());
     return getResult();
@@ -1205,7 +1205,8 @@ OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
   Type resType = getElementTypeOrSelf(getType());
   unsigned bitWidth = resType.cast<IntegerType>().getWidth();
   return constFoldCastOp<IntegerAttr, IntegerAttr>(
-      operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
+      adaptor.getOperands(), getType(),
+      [bitWidth](const APInt &a, bool &castStatus) {
         return a.sext(bitWidth);
       });
 }
@@ -1237,9 +1238,7 @@ LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
 // TruncIOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 1 && "unary operation takes one operand");
-
+OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
   // trunci(zexti(a)) -> a
   // trunci(sexti(a)) -> a
   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
@@ -1255,7 +1254,8 @@ OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
   Type resType = getElementTypeOrSelf(getType());
   unsigned bitWidth = resType.cast<IntegerType>().getWidth();
   return constFoldCastOp<IntegerAttr, IntegerAttr>(
-      operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
+      adaptor.getOperands(), getType(),
+      [bitWidth](const APInt &a, bool &castStatus) {
         return a.trunc(bitWidth);
       });
 }
@@ -1280,10 +1280,8 @@ LogicalResult arith::TruncIOp::verify() {
 
 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
 /// can be represented without precision loss or rounding.
-OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 1 && "unary operation takes one operand");
-
-  auto constOperand = operands.front();
+OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
+  auto constOperand = adaptor.getIn();
   if (!constOperand || !constOperand.isa<FloatAttr>())
     return {};
 
@@ -1348,10 +1346,11 @@ bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
 }
 
-OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
   Type resEleType = getElementTypeOrSelf(getType());
   return constFoldCastOp<IntegerAttr, FloatAttr>(
-      operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
+      adaptor.getOperands(), getType(),
+      [&resEleType](const APInt &a, bool &castStatus) {
         FloatType floatTy = resEleType.cast<FloatType>();
         APFloat apf(floatTy.getFloatSemantics(),
                     APInt::getZero(floatTy.getWidth()));
@@ -1369,10 +1368,11 @@ bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
 }
 
-OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
   Type resEleType = getElementTypeOrSelf(getType());
   return constFoldCastOp<IntegerAttr, FloatAttr>(
-      operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
+      adaptor.getOperands(), getType(),
+      [&resEleType](const APInt &a, bool &castStatus) {
         FloatType floatTy = resEleType.cast<FloatType>();
         APFloat apf(floatTy.getFloatSemantics(),
                     APInt::getZero(floatTy.getWidth()));
@@ -1389,11 +1389,12 @@ bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
 }
 
-OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
   Type resType = getElementTypeOrSelf(getType());
   unsigned bitWidth = resType.cast<IntegerType>().getWidth();
   return constFoldCastOp<FloatAttr, IntegerAttr>(
-      operands, getType(), [&bitWidth](const APFloat &a, bool &castStatus) {
+      adaptor.getOperands(), getType(),
+      [&bitWidth](const APFloat &a, bool &castStatus) {
         bool ignored;
         APSInt api(bitWidth, /*isUnsigned=*/true);
         castStatus = APFloat::opInvalidOp !=
@@ -1410,11 +1411,12 @@ bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
 }
 
-OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
   Type resType = getElementTypeOrSelf(getType());
   unsigned bitWidth = resType.cast<IntegerType>().getWidth();
   return constFoldCastOp<FloatAttr, IntegerAttr>(
-      operands, getType(), [&bitWidth](const APFloat &a, bool &castStatus) {
+      adaptor.getOperands(), getType(),
+      [&bitWidth](const APFloat &a, bool &castStatus) {
         bool ignored;
         APSInt api(bitWidth, /*isUnsigned=*/false);
         castStatus = APFloat::opInvalidOp !=
@@ -1445,11 +1447,11 @@ bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
   return areIndexCastCompatible(inputs, outputs);
 }
 
-OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
   // index_cast(constant) -> constant
   // A little hack because we go through int. Otherwise, the size of the
   // constant might need to change.
-  if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
+  if (auto value = adaptor.getIn().dyn_cast_or_null<IntegerAttr>())
     return IntegerAttr::get(getType(), value.getInt());
 
   return {};
@@ -1469,11 +1471,11 @@ bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
   return areIndexCastCompatible(inputs, outputs);
 }
 
-OpFoldResult arith::IndexCastUIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
   // index_castui(constant) -> constant
   // A little hack because we go through int. Otherwise, the size of the
   // constant might need to change.
-  if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
+  if (auto value = adaptor.getIn().dyn_cast_or_null<IntegerAttr>())
     return IntegerAttr::get(getType(), value.getValue().getZExtValue());
 
   return {};
@@ -1502,11 +1504,9 @@ bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
 }
 
-OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 1 && "bitcast op expects 1 operand");
-
+OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
   auto resType = getType();
-  auto operand = operands[0];
+  auto operand = adaptor.getIn();
   if (!operand)
     return {};
 
@@ -1620,9 +1620,7 @@ static std::optional<int64_t> getIntegerWidth(Type t) {
   return std::nullopt;
 }
 
-OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 2 && "cmpi takes two operands");
-
+OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
   // cmpi(pred, x, x)
   if (getLhs() == getRhs()) {
     auto val = applyCmpPredicateToEqualOperands(getPredicate());
@@ -1649,7 +1647,7 @@ OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
   }
 
   // Move constant to the right side.
-  if (operands[0] && !operands[1]) {
+  if (adaptor.getLhs() && !adaptor.getRhs()) {
     // Do not use invertPredicate, as it will change eq to ne and vice versa.
     using Pred = CmpIPredicate;
     const std::pair<Pred, Pred> invPreds[] = {
@@ -1672,13 +1670,13 @@ OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
     llvm_unreachable("unknown cmpi predicate kind");
   }
 
-  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
+  auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
   if (!lhs)
     return {};
 
   // We are moving constants to the right side; So if lhs is constant rhs is
   // guaranteed to be a constant.
-  auto rhs = operands.back().cast<IntegerAttr>();
+  auto rhs = adaptor.getRhs().cast<IntegerAttr>();
 
   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
   return BoolAttr::get(getContext(), val);
@@ -1741,11 +1739,9 @@ bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
   llvm_unreachable("unknown cmpf predicate kind");
 }
 
-OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 2 && "cmpf takes two operands");
-
-  auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
-  auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
+OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
+  auto lhs = adaptor.getLhs().dyn_cast_or_null<FloatAttr>();
+  auto rhs = adaptor.getRhs().dyn_cast_or_null<FloatAttr>();
 
   // If one operand is NaN, making them both NaN does not change the result.
   if (lhs && lhs.getValue().isNaN())
@@ -2123,7 +2119,7 @@ void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<SelectI1Simplify, SelectToExtUI>(context);
 }
 
-OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
   Value trueVal = getTrueValue();
   Value falseVal = getFalseValue();
   if (trueVal == falseVal)
@@ -2220,14 +2216,14 @@ LogicalResult arith::SelectOp::verify() {
 // ShLIOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
   // shli(x, 0) -> x
   if (matchPattern(getRhs(), m_Zero()))
     return getLhs();
   // Don't fold if shifting more than the bit width.
   bool bounded = false;
   auto result = constFoldBinaryOp<IntegerAttr>(
-      operands, [&](const APInt &a, const APInt &b) {
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
         bounded = b.ule(b.getBitWidth());
         return a.shl(b);
       });
@@ -2238,14 +2234,14 @@ OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
 // ShRUIOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
   // shrui(x, 0) -> x
   if (matchPattern(getRhs(), m_Zero()))
     return getLhs();
   // Don't fold if shifting more than the bit width.
   bool bounded = false;
   auto result = constFoldBinaryOp<IntegerAttr>(
-      operands, [&](const APInt &a, const APInt &b) {
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
         bounded = b.ule(b.getBitWidth());
         return a.lshr(b);
       });
@@ -2256,14 +2252,14 @@ OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
 // ShRSIOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
   // shrsi(x, 0) -> x
   if (matchPattern(getRhs(), m_Zero()))
     return getLhs();
   // Don't fold if shifting more than the bit width.
   bool bounded = false;
   auto result = constFoldBinaryOp<IntegerAttr>(
-      operands, [&](const APInt &a, const APInt &b) {
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
         bounded = b.ule(b.getBitWidth());
         return a.ashr(b);
       });


        


More information about the Mlir-commits mailing list