[Mlir-commits] [mlir] [mlir][math] Use APFloat::SemanticsToEnum in constant folding (PR #193914)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 24 00:48:39 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

Refactor constant folding in the Math dialect to use APFloat::SemanticsToEnum() instead of getSizeInBits() when checking floating-point semantics. Inferring semantics from bit width is fragile: different formats may share the same bit width but have distinct semantics (e.g., `f16` and `bf16` are both 16 bits wide), leading to incorrect dispatch. SemanticsToEnum() matches on the exact semantics descriptor, making the intent explicit and ensuring correct dispatch.

---
Full diff: https://github.com/llvm/llvm-project/pull/193914.diff


1 Files Affected:

- (modified) mlir/lib/Dialect/Math/IR/MathOps.cpp (+90-90) 


``````````diff
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 4c0274ddb18a1..bec95f58260be 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -62,10 +62,10 @@ OpFoldResult math::AbsIOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(acos(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(acosf(a.convertToFloat()));
         default:
           return {};
@@ -80,10 +80,10 @@ OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::AcoshOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(acosh(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(acoshf(a.convertToFloat()));
         default:
           return {};
@@ -98,10 +98,10 @@ OpFoldResult math::AcoshOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::AsinOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(asin(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(asinf(a.convertToFloat()));
         default:
           return {};
@@ -116,10 +116,10 @@ OpFoldResult math::AsinOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::AsinhOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(asinh(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(asinhf(a.convertToFloat()));
         default:
           return {};
@@ -134,10 +134,10 @@ OpFoldResult math::AsinhOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(atan(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(atanf(a.convertToFloat()));
         default:
           return {};
@@ -152,10 +152,10 @@ OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::AtanhOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(atanh(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(atanhf(a.convertToFloat()));
         default:
           return {};
@@ -174,15 +174,14 @@ OpFoldResult math::Atan2Op::fold(FoldAdaptor adaptor) {
         if (a.isZero() && b.isZero())
           return llvm::APFloat::getNaN(a.getSemantics());
 
-        if (a.getSizeInBits(a.getSemantics()) == 64 &&
-            b.getSizeInBits(b.getSemantics()) == 64)
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(atan2(a.convertToDouble(), b.convertToDouble()));
-
-        if (a.getSizeInBits(a.getSemantics()) == 32 &&
-            b.getSizeInBits(b.getSemantics()) == 32)
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(atan2f(a.convertToFloat(), b.convertToFloat()));
-
-        return {};
+        default:
+          return {};
+        }
       });
 }
 
@@ -219,10 +218,10 @@ OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(cos(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(cosf(a.convertToFloat()));
         default:
           return {};
@@ -237,10 +236,10 @@ OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::CoshOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(cosh(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(coshf(a.convertToFloat()));
         default:
           return {};
@@ -255,10 +254,10 @@ OpFoldResult math::CoshOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(sin(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(sinf(a.convertToFloat()));
         default:
           return {};
@@ -273,10 +272,10 @@ OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(sinh(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(sinhf(a.convertToFloat()));
         default:
           return {};
@@ -331,10 +330,10 @@ OpFoldResult math::CtPopOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(erf(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(erff(a.convertToFloat()));
         default:
           return {};
@@ -424,13 +423,14 @@ OpFoldResult math::LogOp::fold(FoldAdaptor adaptor) {
         if (a.isNegative())
           return {};
 
-        if (a.getSizeInBits(a.getSemantics()) == 64)
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(log(a.convertToDouble()));
-
-        if (a.getSizeInBits(a.getSemantics()) == 32)
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(logf(a.convertToFloat()));
-
-        return {};
+        default:
+          return {};
+        }
       });
 }
 
@@ -444,13 +444,14 @@ OpFoldResult math::Log2Op::fold(FoldAdaptor adaptor) {
         if (a.isNegative())
           return {};
 
-        if (a.getSizeInBits(a.getSemantics()) == 64)
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(log2(a.convertToDouble()));
-
-        if (a.getSizeInBits(a.getSemantics()) == 32)
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(log2f(a.convertToFloat()));
-
-        return {};
+        default:
+          return {};
+        }
       });
 }
 
@@ -464,10 +465,10 @@ OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) {
         if (a.isNegative())
           return {};
 
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(log10(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(log10f(a.convertToFloat()));
         default:
           return {};
@@ -482,12 +483,12 @@ OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) {
 OpFoldResult math::Log1pOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           if ((a + APFloat(1.0)).isNegative())
             return {};
           return APFloat(log1p(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           if ((a + APFloat(1.0f)).isNegative())
             return {};
           return APFloat(log1pf(a.convertToFloat()));
@@ -505,15 +506,14 @@ OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) {
   return constFoldBinaryOpConditional<FloatAttr>(
       adaptor.getOperands(),
       [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
-        if (a.getSizeInBits(a.getSemantics()) == 64 &&
-            b.getSizeInBits(b.getSemantics()) == 64)
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
-
-        if (a.getSizeInBits(a.getSemantics()) == 32 &&
-            b.getSizeInBits(b.getSemantics()) == 32)
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
-
-        return {};
+        default:
+          return {};
+        }
       });
 }
 
@@ -528,10 +528,10 @@ OpFoldResult math::RsqrtOp::fold(FoldAdaptor adaptor) {
           return {};
 
         APFloat one(a.getSemantics(), 1);
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return one / APFloat(sqrt(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return one / APFloat(sqrtf(a.convertToFloat()));
         default:
           return {};
@@ -549,10 +549,10 @@ OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) {
         if (a.isNegative())
           return {};
 
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(sqrt(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(sqrtf(a.convertToFloat()));
         default:
           return {};
@@ -567,10 +567,10 @@ OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(exp(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(expf(a.convertToFloat()));
         default:
           return {};
@@ -585,10 +585,10 @@ OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(exp2(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(exp2f(a.convertToFloat()));
         default:
           return {};
@@ -603,10 +603,10 @@ OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) {
 OpFoldResult math::ExpM1Op::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(expm1(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(expm1f(a.convertToFloat()));
         default:
           return {};
@@ -685,10 +685,10 @@ OpFoldResult math::IsNormalOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(tan(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(tanf(a.convertToFloat()));
         default:
           return {};
@@ -703,10 +703,10 @@ OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::TanhOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(tanh(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(tanhf(a.convertToFloat()));
         default:
           return {};
@@ -747,10 +747,10 @@ OpFoldResult math::FloorOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(round(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(roundf(a.convertToFloat()));
         default:
           return {};
@@ -765,10 +765,10 @@ OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(trunc(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(truncf(a.convertToFloat()));
         default:
           return {};

``````````

</details>


https://github.com/llvm/llvm-project/pull/193914


More information about the Mlir-commits mailing list