[Mlir-commits] [mlir] 83bd4fe - [mlir][Math] Replace some constant folder functions with common folder functions.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 12 04:35:09 PDT 2022


Author: jacquesguan
Date: 2022-04-12T11:34:47Z
New Revision: 83bd4fe2e83c66796afaeb18386de249133c6732

URL: https://github.com/llvm/llvm-project/commit/83bd4fe2e83c66796afaeb18386de249133c6732
DIFF: https://github.com/llvm/llvm-project/commit/83bd4fe2e83c66796afaeb18386de249133c6732.diff

LOG: [mlir][Math] Replace some constant folder functions with common folder functions.

Differential Revision: https://reviews.llvm.org/D123485

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/IR/MathOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 28f42f814f6dc..036d8d9423327 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/IR/Builders.h"
 
@@ -25,25 +26,10 @@ using namespace mlir::math;
 //===----------------------------------------------------------------------===//
 
 OpFoldResult math::AbsOp::fold(ArrayRef<Attribute> operands) {
-  auto constOperand = operands.front();
-  if (!constOperand)
-    return {};
-
-  auto attr = constOperand.dyn_cast<FloatAttr>();
-  if (!attr)
-    return {};
-
-  auto ft = getType().cast<FloatType>();
-
-  APFloat apf = attr.getValue();
-
-  if (ft.getWidth() == 64)
-    return FloatAttr::get(getType(), fabs(apf.convertToDouble()));
-
-  if (ft.getWidth() == 32)
-    return FloatAttr::get(getType(), fabsf(apf.convertToFloat()));
-
-  return {};
+  return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
+    // llvm::abs takes its operand by value, so no explicit copy is needed.
+    return abs(a);
+  });
 }
 
 //===----------------------------------------------------------------------===//
@@ -51,18 +37,11 @@ OpFoldResult math::AbsOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) {
-  auto constOperand = operands.front();
-  if (!constOperand)
-    return {};
-
-  auto attr = constOperand.dyn_cast<FloatAttr>();
-  if (!attr)
-    return {};
-
-  APFloat sourceVal = attr.getValue();
-  sourceVal.roundToIntegral(llvm::RoundingMode::TowardPositive);
-
-  return FloatAttr::get(getType(), sourceVal);
+  return constFoldUnaryOp<FloatAttr>(operands, [](APFloat value) {
+    // Rounding toward +infinity yields ceil(value); the status is ignored.
+    value.roundToIntegral(llvm::RoundingMode::TowardPositive);
+    return value;
+  });
 }
 
 //===----------------------------------------------------------------------===//
@@ -70,26 +49,12 @@ OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult math::CopySignOp::fold(ArrayRef<Attribute> operands) {
-  auto ft = getType().dyn_cast<FloatType>();
-  if (!ft)
-    return {};
-
-  APFloat vals[2]{APFloat(ft.getFloatSemantics()),
-                  APFloat(ft.getFloatSemantics())};
-  for (int i = 0; i < 2; ++i) {
-    if (!operands[i])
-      return {};
-
-    auto attr = operands[i].dyn_cast<FloatAttr>();
-    if (!attr)
-      return {};
-
-    vals[i] = attr.getValue();
-  }
-
-  vals[0].copySign(vals[1]);
-
-  return FloatAttr::get(getType(), vals[0]);
+  return constFoldBinaryOp<FloatAttr>(
+      operands, [](APFloat magnitude, const APFloat &sign) {
+        // Keep the magnitude of the first operand, the sign of the second.
+        magnitude.copySign(sign);
+        return magnitude;
+      });
 }
 
 //===----------------------------------------------------------------------===//
@@ -97,15 +62,9 @@ OpFoldResult math::CopySignOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult math::CountLeadingZerosOp::fold(ArrayRef<Attribute> operands) {
-  auto constOperand = operands.front();
-  if (!constOperand)
-    return {};
-
-  auto attr = constOperand.dyn_cast<IntegerAttr>();
-  if (!attr)
-    return {};
-
-  return IntegerAttr::get(getType(), attr.getValue().countLeadingZeros());
+  return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
+    return APInt(a.getBitWidth(), a.countLeadingZeros()); // Keep a's width.
+  });
 }
 
 //===----------------------------------------------------------------------===//
@@ -113,15 +72,9 @@ OpFoldResult math::CountLeadingZerosOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult math::CountTrailingZerosOp::fold(ArrayRef<Attribute> operands) {
-  auto constOperand = operands.front();
-  if (!constOperand)
-    return {};
-
-  auto attr = constOperand.dyn_cast<IntegerAttr>();
-  if (!attr)
-    return {};
-
-  return IntegerAttr::get(getType(), attr.getValue().countTrailingZeros());
+  return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
+    return APInt(a.getBitWidth(), a.countTrailingZeros()); // Keep a's width.
+  });
 }
 
 //===----------------------------------------------------------------------===//
@@ -129,15 +82,9 @@ OpFoldResult math::CountTrailingZerosOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult math::CtPopOp::fold(ArrayRef<Attribute> operands) {
-  auto constOperand = operands.front();
-  if (!constOperand)
-    return {};
-
-  auto attr = constOperand.dyn_cast<IntegerAttr>();
-  if (!attr)
-    return {};
-
-  return IntegerAttr::get(getType(), attr.getValue().countPopulation());
+  return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
+    return APInt(a.getBitWidth(), a.countPopulation()); // Keep a's width.
+  });
 }
 
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list