[Mlir-commits] [mlir] 83bd4fe - [mlir][Math] Replace some constant folder functions with common folder functions.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 12 04:35:09 PDT 2022
Author: jacquesguan
Date: 2022-04-12T11:34:47Z
New Revision: 83bd4fe2e83c66796afaeb18386de249133c6732
URL: https://github.com/llvm/llvm-project/commit/83bd4fe2e83c66796afaeb18386de249133c6732
DIFF: https://github.com/llvm/llvm-project/commit/83bd4fe2e83c66796afaeb18386de249133c6732.diff
LOG: [mlir][Math] Replace some constant folder functions with common folder functions.
Differential Revision: https://reviews.llvm.org/D123485
Added: (none)
Modified:
mlir/lib/Dialect/Math/IR/MathOps.cpp
Removed: (none)
################################################################################
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 28f42f814f6dc..036d8d9423327 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/Builders.h"
@@ -25,25 +26,10 @@ using namespace mlir::math;
//===----------------------------------------------------------------------===//
OpFoldResult math::AbsOp::fold(ArrayRef<Attribute> operands) {
- auto constOperand = operands.front();
- if (!constOperand)
- return {};
-
- auto attr = constOperand.dyn_cast<FloatAttr>();
- if (!attr)
- return {};
-
- auto ft = getType().cast<FloatType>();
-
- APFloat apf = attr.getValue();
-
- if (ft.getWidth() == 64)
- return FloatAttr::get(getType(), fabs(apf.convertToDouble()));
-
- if (ft.getWidth() == 32)
- return FloatAttr::get(getType(), fabsf(apf.convertToFloat()));
-
- return {};
+ return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
+ APFloat result(a);
+ return abs(result);
+ });
}
//===----------------------------------------------------------------------===//
@@ -51,18 +37,11 @@ OpFoldResult math::AbsOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) {
- auto constOperand = operands.front();
- if (!constOperand)
- return {};
-
- auto attr = constOperand.dyn_cast<FloatAttr>();
- if (!attr)
- return {};
-
- APFloat sourceVal = attr.getValue();
- sourceVal.roundToIntegral(llvm::RoundingMode::TowardPositive);
-
- return FloatAttr::get(getType(), sourceVal);
+ return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
+ APFloat result(a);
+ result.roundToIntegral(llvm::RoundingMode::TowardPositive);
+ return result;
+ });
}
//===----------------------------------------------------------------------===//
@@ -70,26 +49,12 @@ OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
OpFoldResult math::CopySignOp::fold(ArrayRef<Attribute> operands) {
- auto ft = getType().dyn_cast<FloatType>();
- if (!ft)
- return {};
-
- APFloat vals[2]{APFloat(ft.getFloatSemantics()),
- APFloat(ft.getFloatSemantics())};
- for (int i = 0; i < 2; ++i) {
- if (!operands[i])
- return {};
-
- auto attr = operands[i].dyn_cast<FloatAttr>();
- if (!attr)
- return {};
-
- vals[i] = attr.getValue();
- }
-
- vals[0].copySign(vals[1]);
-
- return FloatAttr::get(getType(), vals[0]);
+ return constFoldBinaryOp<FloatAttr>(operands,
+ [](const APFloat &a, const APFloat &b) {
+ APFloat result(a);
+ result.copySign(b);
+ return result;
+ });
}
//===----------------------------------------------------------------------===//
@@ -97,15 +62,9 @@ OpFoldResult math::CopySignOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
OpFoldResult math::CountLeadingZerosOp::fold(ArrayRef<Attribute> operands) {
- auto constOperand = operands.front();
- if (!constOperand)
- return {};
-
- auto attr = constOperand.dyn_cast<IntegerAttr>();
- if (!attr)
- return {};
-
- return IntegerAttr::get(getType(), attr.getValue().countLeadingZeros());
+ return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
+ return APInt(a.getBitWidth(), a.countLeadingZeros());
+ });
}
//===----------------------------------------------------------------------===//
@@ -113,15 +72,9 @@ OpFoldResult math::CountLeadingZerosOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
OpFoldResult math::CountTrailingZerosOp::fold(ArrayRef<Attribute> operands) {
- auto constOperand = operands.front();
- if (!constOperand)
- return {};
-
- auto attr = constOperand.dyn_cast<IntegerAttr>();
- if (!attr)
- return {};
-
- return IntegerAttr::get(getType(), attr.getValue().countTrailingZeros());
+ return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
+ return APInt(a.getBitWidth(), a.countTrailingZeros());
+ });
}
//===----------------------------------------------------------------------===//
@@ -129,15 +82,9 @@ OpFoldResult math::CountTrailingZerosOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
OpFoldResult math::CtPopOp::fold(ArrayRef<Attribute> operands) {
- auto constOperand = operands.front();
- if (!constOperand)
- return {};
-
- auto attr = constOperand.dyn_cast<IntegerAttr>();
- if (!attr)
- return {};
-
- return IntegerAttr::get(getType(), attr.getValue().countPopulation());
+ return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
+ return APInt(a.getBitWidth(), a.countPopulation());
+ });
}
//===----------------------------------------------------------------------===//
More information about the Mlir-commits mailing list