[Mlir-commits] [mlir] [mlir][math] Use APFloat::SemanticsToEnum in constant folding (PR #193914)
Longsheng Mou
llvmlistbot at llvm.org
Fri Apr 24 00:48:07 PDT 2026
https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/193914
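Refactor constant folding in the Math dialect to dispatch on APFloat::SemanticsToEnum() instead of getSizeInBits() when choosing which host C math function (e.g. acos vs. acosf) to call. Inferring semantics from bit width is fragile: distinct formats can share a bit width while having different semantics (for example, IEEE half and bfloat16 are both 16 bits wide), so a width-based check can dispatch to the wrong function. SemanticsToEnum() matches on the exact semantics descriptor, which makes the intent explicit and ensures that only IEEE single- and double-precision values are folded through the float/double host functions.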
>From 925d95bdc6153db0d62651ab729a3219e6535c2d Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Fri, 24 Apr 2026 15:46:59 +0800
Subject: [PATCH] [mlir][math] Use APFloat::SemanticsToEnum in constant folding
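Refactor constant folding in the Math dialect to dispatch on APFloat::SemanticsToEnum() instead of getSizeInBits() when choosing which host C math function (e.g. acos vs. acosf) to call. Inferring semantics from bit width is fragile: distinct formats can share a bit width while having different semantics (for example, IEEE half and bfloat16 are both 16 bits wide), so a width-based check can dispatch to the wrong function. SemanticsToEnum() matches on the exact semantics descriptor, which makes the intent explicit and ensures that only IEEE single- and double-precision values are folded through the float/double host functions.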
---
mlir/lib/Dialect/Math/IR/MathOps.cpp | 180 +++++++++++++--------------
1 file changed, 90 insertions(+), 90 deletions(-)
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 4c0274ddb18a1..bec95f58260be 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -62,10 +62,10 @@ OpFoldResult math::AbsIOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(acos(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(acosf(a.convertToFloat()));
default:
return {};
@@ -80,10 +80,10 @@ OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::AcoshOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(acosh(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(acoshf(a.convertToFloat()));
default:
return {};
@@ -98,10 +98,10 @@ OpFoldResult math::AcoshOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::AsinOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(asin(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(asinf(a.convertToFloat()));
default:
return {};
@@ -116,10 +116,10 @@ OpFoldResult math::AsinOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::AsinhOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(asinh(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(asinhf(a.convertToFloat()));
default:
return {};
@@ -134,10 +134,10 @@ OpFoldResult math::AsinhOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(atan(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(atanf(a.convertToFloat()));
default:
return {};
@@ -152,10 +152,10 @@ OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::AtanhOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(atanh(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(atanhf(a.convertToFloat()));
default:
return {};
@@ -174,15 +174,14 @@ OpFoldResult math::Atan2Op::fold(FoldAdaptor adaptor) {
if (a.isZero() && b.isZero())
return llvm::APFloat::getNaN(a.getSemantics());
- if (a.getSizeInBits(a.getSemantics()) == 64 &&
- b.getSizeInBits(b.getSemantics()) == 64)
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(atan2(a.convertToDouble(), b.convertToDouble()));
-
- if (a.getSizeInBits(a.getSemantics()) == 32 &&
- b.getSizeInBits(b.getSemantics()) == 32)
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(atan2f(a.convertToFloat(), b.convertToFloat()));
-
- return {};
+ default:
+ return {};
+ }
});
}
@@ -219,10 +218,10 @@ OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(cos(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(cosf(a.convertToFloat()));
default:
return {};
@@ -237,10 +236,10 @@ OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::CoshOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(cosh(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(coshf(a.convertToFloat()));
default:
return {};
@@ -255,10 +254,10 @@ OpFoldResult math::CoshOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(sin(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(sinf(a.convertToFloat()));
default:
return {};
@@ -273,10 +272,10 @@ OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(sinh(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(sinhf(a.convertToFloat()));
default:
return {};
@@ -331,10 +330,10 @@ OpFoldResult math::CtPopOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(erf(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(erff(a.convertToFloat()));
default:
return {};
@@ -424,13 +423,14 @@ OpFoldResult math::LogOp::fold(FoldAdaptor adaptor) {
if (a.isNegative())
return {};
- if (a.getSizeInBits(a.getSemantics()) == 64)
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(log(a.convertToDouble()));
-
- if (a.getSizeInBits(a.getSemantics()) == 32)
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(logf(a.convertToFloat()));
-
- return {};
+ default:
+ return {};
+ }
});
}
@@ -444,13 +444,14 @@ OpFoldResult math::Log2Op::fold(FoldAdaptor adaptor) {
if (a.isNegative())
return {};
- if (a.getSizeInBits(a.getSemantics()) == 64)
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(log2(a.convertToDouble()));
-
- if (a.getSizeInBits(a.getSemantics()) == 32)
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(log2f(a.convertToFloat()));
-
- return {};
+ default:
+ return {};
+ }
});
}
@@ -464,10 +465,10 @@ OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) {
if (a.isNegative())
return {};
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(log10(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(log10f(a.convertToFloat()));
default:
return {};
@@ -482,12 +483,12 @@ OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) {
OpFoldResult math::Log1pOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
if ((a + APFloat(1.0)).isNegative())
return {};
return APFloat(log1p(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
if ((a + APFloat(1.0f)).isNegative())
return {};
return APFloat(log1pf(a.convertToFloat()));
@@ -505,15 +506,14 @@ OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) {
return constFoldBinaryOpConditional<FloatAttr>(
adaptor.getOperands(),
[](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
- if (a.getSizeInBits(a.getSemantics()) == 64 &&
- b.getSizeInBits(b.getSemantics()) == 64)
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
-
- if (a.getSizeInBits(a.getSemantics()) == 32 &&
- b.getSizeInBits(b.getSemantics()) == 32)
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
-
- return {};
+ default:
+ return {};
+ }
});
}
@@ -528,10 +528,10 @@ OpFoldResult math::RsqrtOp::fold(FoldAdaptor adaptor) {
return {};
APFloat one(a.getSemantics(), 1);
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return one / APFloat(sqrt(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return one / APFloat(sqrtf(a.convertToFloat()));
default:
return {};
@@ -549,10 +549,10 @@ OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) {
if (a.isNegative())
return {};
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(sqrt(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(sqrtf(a.convertToFloat()));
default:
return {};
@@ -567,10 +567,10 @@ OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(exp(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(expf(a.convertToFloat()));
default:
return {};
@@ -585,10 +585,10 @@ OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(exp2(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(exp2f(a.convertToFloat()));
default:
return {};
@@ -603,10 +603,10 @@ OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) {
OpFoldResult math::ExpM1Op::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(expm1(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(expm1f(a.convertToFloat()));
default:
return {};
@@ -685,10 +685,10 @@ OpFoldResult math::IsNormalOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(tan(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(tanf(a.convertToFloat()));
default:
return {};
@@ -703,10 +703,10 @@ OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::TanhOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(tanh(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(tanhf(a.convertToFloat()));
default:
return {};
@@ -747,10 +747,10 @@ OpFoldResult math::FloorOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(round(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(roundf(a.convertToFloat()));
default:
return {};
@@ -765,10 +765,10 @@ OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) {
OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
- switch (a.getSizeInBits(a.getSemantics())) {
- case 64:
+ switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(trunc(a.convertToDouble()));
- case 32:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(truncf(a.convertToFloat()));
default:
return {};
More information about the Mlir-commits
mailing list