[Mlir-commits] [mlir] 170f030 - [mlir][math] Use APFloat::SemanticsToEnum in constant folding (#193914)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 24 01:56:35 PDT 2026


Author: Longsheng Mou
Date: 2026-04-24T16:56:30+08:00
New Revision: 170f030c22c52bc5d840cd8f76756c9fba5ee816

URL: https://github.com/llvm/llvm-project/commit/170f030c22c52bc5d840cd8f76756c9fba5ee816
DIFF: https://github.com/llvm/llvm-project/commit/170f030c22c52bc5d840cd8f76756c9fba5ee816.diff

LOG: [mlir][math] Use APFloat::SemanticsToEnum in constant folding (#193914)

Refactor constant folding in the Math dialect to use APFloat::SemanticsToEnum() instead of getSizeInBits() when checking floating-point semantics. Inferring semantics from the bit width is fragile: different formats may share the same bit width but have distinct semantics (for example, f16 and bf16 are both 16 bits wide), leading to incorrect dispatch. SemanticsToEnum() matches on the exact semantics descriptor, making the intent explicit and ensuring correct dispatch.

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/IR/MathOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 4c0274ddb18a1..bec95f58260be 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -62,10 +62,10 @@ OpFoldResult math::AbsIOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(acos(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(acosf(a.convertToFloat()));
         default:
           return {};
@@ -80,10 +80,10 @@ OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::AcoshOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(acosh(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(acoshf(a.convertToFloat()));
         default:
           return {};
@@ -98,10 +98,10 @@ OpFoldResult math::AcoshOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::AsinOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(asin(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(asinf(a.convertToFloat()));
         default:
           return {};
@@ -116,10 +116,10 @@ OpFoldResult math::AsinOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::AsinhOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(asinh(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(asinhf(a.convertToFloat()));
         default:
           return {};
@@ -134,10 +134,10 @@ OpFoldResult math::AsinhOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(atan(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(atanf(a.convertToFloat()));
         default:
           return {};
@@ -152,10 +152,10 @@ OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::AtanhOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(atanh(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(atanhf(a.convertToFloat()));
         default:
           return {};
@@ -174,15 +174,14 @@ OpFoldResult math::Atan2Op::fold(FoldAdaptor adaptor) {
         if (a.isZero() && b.isZero())
           return llvm::APFloat::getNaN(a.getSemantics());
 
-        if (a.getSizeInBits(a.getSemantics()) == 64 &&
-            b.getSizeInBits(b.getSemantics()) == 64)
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(atan2(a.convertToDouble(), b.convertToDouble()));
-
-        if (a.getSizeInBits(a.getSemantics()) == 32 &&
-            b.getSizeInBits(b.getSemantics()) == 32)
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(atan2f(a.convertToFloat(), b.convertToFloat()));
-
-        return {};
+        default:
+          return {};
+        }
       });
 }
 
@@ -219,10 +218,10 @@ OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(cos(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(cosf(a.convertToFloat()));
         default:
           return {};
@@ -237,10 +236,10 @@ OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::CoshOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(cosh(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(coshf(a.convertToFloat()));
         default:
           return {};
@@ -255,10 +254,10 @@ OpFoldResult math::CoshOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(sin(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(sinf(a.convertToFloat()));
         default:
           return {};
@@ -273,10 +272,10 @@ OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(sinh(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(sinhf(a.convertToFloat()));
         default:
           return {};
@@ -331,10 +330,10 @@ OpFoldResult math::CtPopOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(erf(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(erff(a.convertToFloat()));
         default:
           return {};
@@ -424,13 +423,14 @@ OpFoldResult math::LogOp::fold(FoldAdaptor adaptor) {
         if (a.isNegative())
           return {};
 
-        if (a.getSizeInBits(a.getSemantics()) == 64)
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(log(a.convertToDouble()));
-
-        if (a.getSizeInBits(a.getSemantics()) == 32)
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(logf(a.convertToFloat()));
-
-        return {};
+        default:
+          return {};
+        }
       });
 }
 
@@ -444,13 +444,14 @@ OpFoldResult math::Log2Op::fold(FoldAdaptor adaptor) {
         if (a.isNegative())
           return {};
 
-        if (a.getSizeInBits(a.getSemantics()) == 64)
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(log2(a.convertToDouble()));
-
-        if (a.getSizeInBits(a.getSemantics()) == 32)
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(log2f(a.convertToFloat()));
-
-        return {};
+        default:
+          return {};
+        }
       });
 }
 
@@ -464,10 +465,10 @@ OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) {
         if (a.isNegative())
           return {};
 
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(log10(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(log10f(a.convertToFloat()));
         default:
           return {};
@@ -482,12 +483,12 @@ OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) {
 OpFoldResult math::Log1pOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           if ((a + APFloat(1.0)).isNegative())
             return {};
           return APFloat(log1p(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           if ((a + APFloat(1.0f)).isNegative())
             return {};
           return APFloat(log1pf(a.convertToFloat()));
@@ -505,15 +506,14 @@ OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) {
   return constFoldBinaryOpConditional<FloatAttr>(
       adaptor.getOperands(),
       [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
-        if (a.getSizeInBits(a.getSemantics()) == 64 &&
-            b.getSizeInBits(b.getSemantics()) == 64)
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
-
-        if (a.getSizeInBits(a.getSemantics()) == 32 &&
-            b.getSizeInBits(b.getSemantics()) == 32)
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
-
-        return {};
+        default:
+          return {};
+        }
       });
 }
 
@@ -528,10 +528,10 @@ OpFoldResult math::RsqrtOp::fold(FoldAdaptor adaptor) {
           return {};
 
         APFloat one(a.getSemantics(), 1);
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return one / APFloat(sqrt(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return one / APFloat(sqrtf(a.convertToFloat()));
         default:
           return {};
@@ -549,10 +549,10 @@ OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) {
         if (a.isNegative())
           return {};
 
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(sqrt(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(sqrtf(a.convertToFloat()));
         default:
           return {};
@@ -567,10 +567,10 @@ OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(exp(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(expf(a.convertToFloat()));
         default:
           return {};
@@ -585,10 +585,10 @@ OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(exp2(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(exp2f(a.convertToFloat()));
         default:
           return {};
@@ -603,10 +603,10 @@ OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) {
 OpFoldResult math::ExpM1Op::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(expm1(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(expm1f(a.convertToFloat()));
         default:
           return {};
@@ -685,10 +685,10 @@ OpFoldResult math::IsNormalOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(tan(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(tanf(a.convertToFloat()));
         default:
           return {};
@@ -703,10 +703,10 @@ OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::TanhOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(tanh(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(tanhf(a.convertToFloat()));
         default:
           return {};
@@ -747,10 +747,10 @@ OpFoldResult math::FloorOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(round(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(roundf(a.convertToFloat()));
         default:
           return {};
@@ -765,10 +765,10 @@ OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) {
 OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
   return constFoldUnaryOpConditional<FloatAttr>(
       adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
-        switch (a.getSizeInBits(a.getSemantics())) {
-        case 64:
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble:
           return APFloat(trunc(a.convertToDouble()));
-        case 32:
+        case APFloat::Semantics::S_IEEEsingle:
           return APFloat(truncf(a.convertToFloat()));
         default:
           return {};


        


More information about the Mlir-commits mailing list