[Mlir-commits] [mlir] 6c30503 - [mlir][math][NFC] Migrate math dialect to the new fold API

Markus Böck llvmlistbot at llvm.org
Wed Jan 11 09:12:13 PST 2023


Author: Markus Böck
Date: 2023-01-11T18:11:46+01:00
New Revision: 6c30503ef8b41625cdf705fbea5eb0dacdc2c0ae

URL: https://github.com/llvm/llvm-project/commit/6c30503ef8b41625cdf705fbea5eb0dacdc2c0ae
DIFF: https://github.com/llvm/llvm-project/commit/6c30503ef8b41625cdf705fbea5eb0dacdc2c0ae.diff

LOG: [mlir][math][NFC] Migrate math dialect to the new fold API

See https://discourse.llvm.org/t/psa-new-improved-fold-method-signature-has-landed-please-update-your-downstream-projects/67618 for context

Similar to the patch for the arith dialect, the math dialect's fold implementations make heavy use of generic fold functions, so the change is comparatively mechanical and mostly consists of updating the function signatures.

Differential Revision: https://reviews.llvm.org/D141500

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Math/IR/MathBase.td
    mlir/lib/Dialect/Math/IR/MathOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Math/IR/MathBase.td b/mlir/include/mlir/Dialect/Math/IR/MathBase.td
index 0189fd538b496..e63db4ca5db7f 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathBase.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathBase.td
@@ -30,5 +30,6 @@ def Math_Dialect : Dialect {
     ```
   }];
   let hasConstantMaterializer = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }
 #endif // MATH_BASE

diff  --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 621bfa5ea12ae..78186ba48cbd1 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -25,8 +25,8 @@ using namespace mlir::math;
 // AbsFOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::AbsFOp::fold(ArrayRef<Attribute> operands) {
-  return constFoldUnaryOp<FloatAttr>(operands,
+OpFoldResult math::AbsFOp::fold(FoldAdaptor adaptor) {
+  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
                                      [](const APFloat &a) { return abs(a); });
 }
 
@@ -34,8 +34,8 @@ OpFoldResult math::AbsFOp::fold(ArrayRef<Attribute> operands) {
 // AbsIOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::AbsIOp::fold(ArrayRef<Attribute> operands) {
-  return constFoldUnaryOp<IntegerAttr>(operands,
+OpFoldResult math::AbsIOp::fold(FoldAdaptor adaptor) {
+  return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
                                        [](const APInt &a) { return a.abs(); });
 }
 
@@ -43,9 +43,9 @@ OpFoldResult math::AbsIOp::fold(ArrayRef<Attribute> operands) {
 // AtanOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::AtanOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           return APFloat(atan(a.convertToDouble()));
@@ -61,9 +61,10 @@ OpFoldResult math::AtanOp::fold(ArrayRef<Attribute> operands) {
 // Atan2Op folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::Atan2Op::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::Atan2Op::fold(FoldAdaptor adaptor) {
   return constFoldBinaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a, const APFloat &b) -> Optional<APFloat> {
+      adaptor.getOperands(),
+      [](const APFloat &a, const APFloat &b) -> Optional<APFloat> {
         if (a.isZero() && b.isZero())
           return llvm::APFloat::getNaN(a.getSemantics());
 
@@ -83,20 +84,21 @@ OpFoldResult math::Atan2Op::fold(ArrayRef<Attribute> operands) {
 // CeilOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) {
-  return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
-    APFloat result(a);
-    result.roundToIntegral(llvm::RoundingMode::TowardPositive);
-    return result;
-  });
+OpFoldResult math::CeilOp::fold(FoldAdaptor adaptor) {
+  return constFoldUnaryOp<FloatAttr>(
+      adaptor.getOperands(), [](const APFloat &a) {
+        APFloat result(a);
+        result.roundToIntegral(llvm::RoundingMode::TowardPositive);
+        return result;
+      });
 }
 
 //===----------------------------------------------------------------------===//
 // CopySignOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::CopySignOp::fold(ArrayRef<Attribute> operands) {
-  return constFoldBinaryOp<FloatAttr>(operands,
+OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) {
+  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
                                       [](const APFloat &a, const APFloat &b) {
                                         APFloat result(a);
                                         result.copySign(b);
@@ -108,9 +110,9 @@ OpFoldResult math::CopySignOp::fold(ArrayRef<Attribute> operands) {
 // CosOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::CosOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           return APFloat(cos(a.convertToDouble()));
@@ -126,9 +128,9 @@ OpFoldResult math::CosOp::fold(ArrayRef<Attribute> operands) {
 // SinOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::SinOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           return APFloat(sin(a.convertToDouble()));
@@ -144,39 +146,42 @@ OpFoldResult math::SinOp::fold(ArrayRef<Attribute> operands) {
 // CountLeadingZerosOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::CountLeadingZerosOp::fold(ArrayRef<Attribute> operands) {
-  return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
-    return APInt(a.getBitWidth(), a.countLeadingZeros());
-  });
+OpFoldResult math::CountLeadingZerosOp::fold(FoldAdaptor adaptor) {
+  return constFoldUnaryOp<IntegerAttr>(
+      adaptor.getOperands(), [](const APInt &a) {
+        return APInt(a.getBitWidth(), a.countLeadingZeros());
+      });
 }
 
 //===----------------------------------------------------------------------===//
 // CountTrailingZerosOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::CountTrailingZerosOp::fold(ArrayRef<Attribute> operands) {
-  return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
-    return APInt(a.getBitWidth(), a.countTrailingZeros());
-  });
+OpFoldResult math::CountTrailingZerosOp::fold(FoldAdaptor adaptor) {
+  return constFoldUnaryOp<IntegerAttr>(
+      adaptor.getOperands(), [](const APInt &a) {
+        return APInt(a.getBitWidth(), a.countTrailingZeros());
+      });
 }
 
 //===----------------------------------------------------------------------===//
 // CtPopOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::CtPopOp::fold(ArrayRef<Attribute> operands) {
-  return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
-    return APInt(a.getBitWidth(), a.countPopulation());
-  });
+OpFoldResult math::CtPopOp::fold(FoldAdaptor adaptor) {
+  return constFoldUnaryOp<IntegerAttr>(
+      adaptor.getOperands(), [](const APInt &a) {
+        return APInt(a.getBitWidth(), a.countPopulation());
+      });
 }
 
 //===----------------------------------------------------------------------===//
 // ErfOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::ErfOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           return APFloat(erf(a.convertToDouble()));
@@ -192,9 +197,10 @@ OpFoldResult math::ErfOp::fold(ArrayRef<Attribute> operands) {
 // IPowIOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::IPowIOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::IPowIOp::fold(FoldAdaptor adaptor) {
   return constFoldBinaryOpConditional<IntegerAttr>(
-      operands, [](const APInt &base, const APInt &power) -> Optional<APInt> {
+      adaptor.getOperands(),
+      [](const APInt &base, const APInt &power) -> Optional<APInt> {
         unsigned width = base.getBitWidth();
         auto zeroValue = APInt::getZero(width);
         APInt oneValue{width, 1ULL, /*isSigned=*/true};
@@ -242,9 +248,9 @@ OpFoldResult math::IPowIOp::fold(ArrayRef<Attribute> operands) {
 // LogOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::LogOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::LogOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         if (a.isNegative())
           return {};
 
@@ -262,9 +268,9 @@ OpFoldResult math::LogOp::fold(ArrayRef<Attribute> operands) {
 // Log2Op folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::Log2Op::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::Log2Op::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         if (a.isNegative())
           return {};
 
@@ -282,9 +288,9 @@ OpFoldResult math::Log2Op::fold(ArrayRef<Attribute> operands) {
 // Log10Op folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::Log10Op::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         if (a.isNegative())
           return {};
 
@@ -303,9 +309,9 @@ OpFoldResult math::Log10Op::fold(ArrayRef<Attribute> operands) {
 // Log1pOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::Log1pOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::Log1pOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           if ((a + APFloat(1.0)).isNegative())
@@ -325,9 +331,10 @@ OpFoldResult math::Log1pOp::fold(ArrayRef<Attribute> operands) {
 // PowFOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::PowFOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) {
   return constFoldBinaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a, const APFloat &b) -> Optional<APFloat> {
+      adaptor.getOperands(),
+      [](const APFloat &a, const APFloat &b) -> Optional<APFloat> {
         if (a.getSizeInBits(a.getSemantics()) == 64 &&
             b.getSizeInBits(b.getSemantics()) == 64)
           return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
@@ -344,9 +351,9 @@ OpFoldResult math::PowFOp::fold(ArrayRef<Attribute> operands) {
 // SqrtOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::SqrtOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         if (a.isNegative())
           return {};
 
@@ -365,9 +372,9 @@ OpFoldResult math::SqrtOp::fold(ArrayRef<Attribute> operands) {
 // ExpOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::ExpOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           return APFloat(exp(a.convertToDouble()));
@@ -383,9 +390,9 @@ OpFoldResult math::ExpOp::fold(ArrayRef<Attribute> operands) {
 // Exp2Op folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::Exp2Op::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           return APFloat(exp2(a.convertToDouble()));
@@ -401,9 +408,9 @@ OpFoldResult math::Exp2Op::fold(ArrayRef<Attribute> operands) {
 // ExpM1Op folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::ExpM1Op::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::ExpM1Op::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           return APFloat(expm1(a.convertToDouble()));
@@ -419,9 +426,9 @@ OpFoldResult math::ExpM1Op::fold(ArrayRef<Attribute> operands) {
 // TanOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::TanOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           return APFloat(tan(a.convertToDouble()));
@@ -437,9 +444,9 @@ OpFoldResult math::TanOp::fold(ArrayRef<Attribute> operands) {
 // TanhOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::TanhOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::TanhOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           return APFloat(tanh(a.convertToDouble()));
@@ -455,33 +462,35 @@ OpFoldResult math::TanhOp::fold(ArrayRef<Attribute> operands) {
 // RoundEvenOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::RoundEvenOp::fold(ArrayRef<Attribute> operands) {
-  return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
-    APFloat result(a);
-    result.roundToIntegral(llvm::RoundingMode::NearestTiesToEven);
-    return result;
-  });
+OpFoldResult math::RoundEvenOp::fold(FoldAdaptor adaptor) {
+  return constFoldUnaryOp<FloatAttr>(
+      adaptor.getOperands(), [](const APFloat &a) {
+        APFloat result(a);
+        result.roundToIntegral(llvm::RoundingMode::NearestTiesToEven);
+        return result;
+      });
 }
 
 //===----------------------------------------------------------------------===//
 // FloorOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::FloorOp::fold(ArrayRef<Attribute> operands) {
-  return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
-    APFloat result(a);
-    result.roundToIntegral(llvm::RoundingMode::TowardNegative);
-    return result;
-  });
+OpFoldResult math::FloorOp::fold(FoldAdaptor adaptor) {
+  return constFoldUnaryOp<FloatAttr>(
+      adaptor.getOperands(), [](const APFloat &a) {
+        APFloat result(a);
+        result.roundToIntegral(llvm::RoundingMode::TowardNegative);
+        return result;
+      });
 }
 
 //===----------------------------------------------------------------------===//
 // RoundOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::RoundOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           return APFloat(round(a.convertToDouble()));
@@ -497,9 +506,9 @@ OpFoldResult math::RoundOp::fold(ArrayRef<Attribute> operands) {
 // TruncOp folder
 //===----------------------------------------------------------------------===//
 
-OpFoldResult math::TruncOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
-      operands, [](const APFloat &a) -> Optional<APFloat> {
+      adaptor.getOperands(), [](const APFloat &a) -> Optional<APFloat> {
         switch (a.getSizeInBits(a.getSemantics())) {
         case 64:
           return APFloat(trunc(a.convertToDouble()));


        


More information about the Mlir-commits mailing list