[Mlir-commits] [mlir] 170f030 - [mlir][math] Use APFloat::SemanticsToEnum in constant folding (#193914)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 24 01:56:35 PDT 2026
Author: Longsheng Mou
Date: 2026-04-24T16:56:30+08:00
New Revision: 170f030c22c52bc5d840cd8f76756c9fba5ee816
URL: https://github.com/llvm/llvm-project/commit/170f030c22c52bc5d840cd8f76756c9fba5ee816
DIFF: https://github.com/llvm/llvm-project/commit/170f030c22c52bc5d840cd8f76756c9fba5ee816.diff
LOG: [mlir][math] Use APFloat::SemanticsToEnum in constant folding (#193914)
Refactor constant folding in the Math dialect to use APFloat::SemanticsToEnum() instead of getSizeInBits() when checking floating-point semantics. Inferring semantics from bit width is fragile: different formats may share the same bit width but have distinct semantics (for example, f16 and bf16 are both 16 bits wide), leading to incorrect dispatch. SemanticsToEnum() matches on the exact semantics descriptor, making the intent explicit and ensuring correct dispatch.
Added:
Modified:
mlir/lib/Dialect/Math/IR/MathOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 4c0274ddb18a1..bec95f58260be 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -62,10 +62,10 @@ OpFoldResult math::AbsIOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(acos(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(acosf(a.convertToFloat()));
default:
return {};
@@ -80,10 +80,10 @@ OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::AcoshOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(acosh(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(acoshf(a.convertToFloat()));
default:
return {};
@@ -98,10 +98,10 @@ OpFoldResult math::AcoshOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::AsinOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(asin(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(asinf(a.convertToFloat()));
default:
return {};
@@ -116,10 +116,10 @@ OpFoldResult math::AsinOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::AsinhOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(asinh(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(asinhf(a.convertToFloat()));
default:
return {};
@@ -134,10 +134,10 @@ OpFoldResult math::AsinhOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(atan(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(atanf(a.convertToFloat()));
default:
return {};
@@ -152,10 +152,10 @@ OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::AtanhOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(atanh(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(atanhf(a.convertToFloat()));
default:
return {};
@@ -174,15 +174,14 @@ OpFoldResult math::Atan2Op::fold(FoldAdaptor adaptor) {
if (a.isZero() && b.isZero())
return llvm::APFloat::getNaN(a.getSemantics());
- if (a.getSizeInBits(a.getSemantics()) == 64 &&
- b.getSizeInBits(b.getSemantics()) == 64)
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(atan2(a.convertToDouble(), b.convertToDouble()));
-
- if (a.getSizeInBits(a.getSemantics()) == 32 &&
- b.getSizeInBits(b.getSemantics()) == 32)
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(atan2f(a.convertToFloat(), b.convertToFloat()));
-
- return {};
+ default:
+ return {};
+ }
});
}
@@ -219,10 +218,10 @@ OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(cos(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(cosf(a.convertToFloat()));
default:
return {};
@@ -237,10 +236,10 @@ OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::CoshOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(cosh(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(coshf(a.convertToFloat()));
default:
return {};
@@ -255,10 +254,10 @@ OpFoldResult math::CoshOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(sin(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(sinf(a.convertToFloat()));
default:
return {};
@@ -273,10 +272,10 @@ OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(sinh(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(sinhf(a.convertToFloat()));
default:
return {};
@@ -331,10 +330,10 @@ OpFoldResult math::CtPopOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(erf(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(erff(a.convertToFloat()));
default:
return {};
@@ -424,13 +423,14 @@ OpFoldResult math::LogOp::fold(FoldAdaptor adaptor) {
if (a.isNegative())
return {};
- if (a.getSizeInBits(a.getSemantics()) == 64)
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(log(a.convertToDouble()));
-
- if (a.getSizeInBits(a.getSemantics()) == 32)
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(logf(a.convertToFloat()));
-
- return {};
+ default:
+ return {};
+ }
});
}
@@ -444,13 +444,14 @@ OpFoldResult math::Log2Op::fold(FoldAdaptor adaptor) {
if (a.isNegative())
return {};
- if (a.getSizeInBits(a.getSemantics()) == 64)
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(log2(a.convertToDouble()));
-
- if (a.getSizeInBits(a.getSemantics()) == 32)
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(log2f(a.convertToFloat()));
-
- return {};
+ default:
+ return {};
+ }
});
}
@@ -464,10 +465,10 @@ OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) {
if (a.isNegative())
return {};
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(log10(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(log10f(a.convertToFloat()));
default:
return {};
@@ -482,12 +483,12 @@ OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) {
OpFoldResult math::Log1pOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
if ((a + APFloat(1.0)).isNegative())
return {};
return APFloat(log1p(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
if ((a + APFloat(1.0f)).isNegative())
return {};
return APFloat(log1pf(a.convertToFloat()));
@@ -505,15 +506,14 @@ OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) {
return constFoldBinaryOpConditional<FloatAttr>(
adaptor.getOperands(),
[](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
- if (a.getSizeInBits(a.getSemantics()) == 64 &&
- b.getSizeInBits(b.getSemantics()) == 64)
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
-
- if (a.getSizeInBits(a.getSemantics()) == 32 &&
- b.getSizeInBits(b.getSemantics()) == 32)
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
-
- return {};
+ default:
+ return {};
+ }
});
}
@@ -528,10 +528,10 @@ OpFoldResult math::RsqrtOp::fold(FoldAdaptor adaptor) {
return {};
APFloat one(a.getSemantics(), 1);
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return one / APFloat(sqrt(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return one / APFloat(sqrtf(a.convertToFloat()));
default:
return {};
@@ -549,10 +549,10 @@ OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) {
if (a.isNegative())
return {};
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(sqrt(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(sqrtf(a.convertToFloat()));
default:
return {};
@@ -567,10 +567,10 @@ OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(exp(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(expf(a.convertToFloat()));
default:
return {};
@@ -585,10 +585,10 @@ OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(exp2(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(exp2f(a.convertToFloat()));
default:
return {};
@@ -603,10 +603,10 @@ OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) {
OpFoldResult math::ExpM1Op::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(expm1(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(expm1f(a.convertToFloat()));
default:
return {};
@@ -685,10 +685,10 @@ OpFoldResult math::IsNormalOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(tan(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(tanf(a.convertToFloat()));
default:
return {};
@@ -703,10 +703,10 @@ OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::TanhOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(tanh(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(tanhf(a.convertToFloat()));
default:
return {};
@@ -747,10 +747,10 @@ OpFoldResult math::FloorOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(round(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(roundf(a.convertToFloat()));
default:
return {};
@@ -765,10 +765,10 @@ OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(trunc(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(truncf(a.convertToFloat()));
default:
return {};
More information about the Mlir-commits mailing list