[clang] [llvm] [APFloat] Add APFloat support for FP4 data type (PR #95392)

Durgadoss R via cfe-commits cfe-commits at lists.llvm.org
Thu Jun 13 11:12:09 PDT 2024


https://github.com/durga4github updated https://github.com/llvm/llvm-project/pull/95392

>From af17388ffd5096a0c50b62dbd8073f957c052bb1 Mon Sep 17 00:00:00 2001
From: Durgadoss R <durgadossr at nvidia.com>
Date: Wed, 12 Jun 2024 23:55:04 +0530
Subject: [PATCH] [APFloat] Add APFloat support for FP4 data type

This patch adds APFloat type support for the E2M1
FP4 datatype. The definitions for this format are
detailed in section 5.3.3 of the OCP specification,
which can be accessed here:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

Signed-off-by: Durgadoss R <durgadossr at nvidia.com>
---
 clang/lib/AST/MicrosoftMangle.cpp  |   1 +
 llvm/include/llvm/ADT/APFloat.h    |   8 +
 llvm/lib/Support/APFloat.cpp       |  25 ++-
 llvm/unittests/ADT/APFloatTest.cpp | 256 ++++++++++++++++++++++++++++-
 4 files changed, 283 insertions(+), 7 deletions(-)

diff --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp
index ffc5d2d4cd8fc..a863ec7a529b9 100644
--- a/clang/lib/AST/MicrosoftMangle.cpp
+++ b/clang/lib/AST/MicrosoftMangle.cpp
@@ -901,6 +901,7 @@ void MicrosoftCXXNameMangler::mangleFloat(llvm::APFloat Number) {
   case APFloat::S_FloatTF32:
   case APFloat::S_Float6E3M2FN:
   case APFloat::S_Float6E2M3FN:
+  case APFloat::S_Float4E2M1FN:
     llvm_unreachable("Tried to mangle unexpected APFloat semantics");
   }
 
diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h
index a9bb6cc9999b1..c24eae8da3797 100644
--- a/llvm/include/llvm/ADT/APFloat.h
+++ b/llvm/include/llvm/ADT/APFloat.h
@@ -197,6 +197,10 @@ struct APFloatBase {
     // types, there are no infinity or NaN values. The format is detailed in
     // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
     S_Float6E2M3FN,
+    // 4-bit floating point number with bit layout S1E2M1. Unlike IEEE-754
+    // types, there are no infinity or NaN values. The format is detailed in
+    // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+    S_Float4E2M1FN,
 
     S_x87DoubleExtended,
     S_MaxSemantics = S_x87DoubleExtended,
@@ -219,6 +223,7 @@ struct APFloatBase {
   static const fltSemantics &FloatTF32() LLVM_READNONE;
   static const fltSemantics &Float6E3M2FN() LLVM_READNONE;
   static const fltSemantics &Float6E2M3FN() LLVM_READNONE;
+  static const fltSemantics &Float4E2M1FN() LLVM_READNONE;
   static const fltSemantics &x87DoubleExtended() LLVM_READNONE;
 
   /// A Pseudo fltsemantic used to construct APFloats that cannot conflict with
@@ -639,6 +644,7 @@ class IEEEFloat final : public APFloatBase {
   APInt convertFloatTF32APFloatToAPInt() const;
   APInt convertFloat6E3M2FNAPFloatToAPInt() const;
   APInt convertFloat6E2M3FNAPFloatToAPInt() const;
+  APInt convertFloat4E2M1FNAPFloatToAPInt() const;
   void initFromAPInt(const fltSemantics *Sem, const APInt &api);
   template <const fltSemantics &S> void initFromIEEEAPInt(const APInt &api);
   void initFromHalfAPInt(const APInt &api);
@@ -656,6 +662,7 @@ class IEEEFloat final : public APFloatBase {
   void initFromFloatTF32APInt(const APInt &api);
   void initFromFloat6E3M2FNAPInt(const APInt &api);
   void initFromFloat6E2M3FNAPInt(const APInt &api);
+  void initFromFloat4E2M1FNAPInt(const APInt &api);
 
   void assign(const IEEEFloat &);
   void copySignificand(const IEEEFloat &);
@@ -1067,6 +1074,7 @@ class APFloat : public APFloatBase {
     // Below Semantics do not support {NaN or Inf}
     case APFloat::S_Float6E3M2FN:
     case APFloat::S_Float6E2M3FN:
+    case APFloat::S_Float4E2M1FN:
       return false;
     }
   }
diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index 1209bf71a287d..47618bc325951 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -69,8 +69,8 @@ enum class fltNonfiniteBehavior {
   // encodings do not distinguish between signalling and quiet NaN.
   NanOnly,
 
-  // This behavior is present in Float6E3M2FN and Float6E2M3FN types,
-  // which do not support Inf or NaN values.
+  // This behavior is present in Float6E3M2FN, Float6E2M3FN, and
+  // Float4E2M1FN types, which do not support Inf or NaN values.
   FiniteOnly,
 };
 
@@ -147,6 +147,8 @@ static constexpr fltSemantics semFloat6E3M2FN = {
     4, -2, 3, 6, fltNonfiniteBehavior::FiniteOnly};
 static constexpr fltSemantics semFloat6E2M3FN = {
     2, 0, 4, 6, fltNonfiniteBehavior::FiniteOnly};
+static constexpr fltSemantics semFloat4E2M1FN = {
+    2, 0, 2, 4, fltNonfiniteBehavior::FiniteOnly};
 static constexpr fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80};
 static constexpr fltSemantics semBogus = {0, 0, 0, 0};
 
@@ -218,6 +220,8 @@ const llvm::fltSemantics &APFloatBase::EnumToSemantics(Semantics S) {
     return Float6E3M2FN();
   case S_Float6E2M3FN:
     return Float6E2M3FN();
+  case S_Float4E2M1FN:
+    return Float4E2M1FN();
   case S_x87DoubleExtended:
     return x87DoubleExtended();
   }
@@ -254,6 +258,8 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) {
     return S_Float6E3M2FN;
   else if (&Sem == &llvm::APFloat::Float6E2M3FN())
     return S_Float6E2M3FN;
+  else if (&Sem == &llvm::APFloat::Float4E2M1FN())
+    return S_Float4E2M1FN;
   else if (&Sem == &llvm::APFloat::x87DoubleExtended())
     return S_x87DoubleExtended;
   else
@@ -278,6 +284,7 @@ const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() {
 const fltSemantics &APFloatBase::FloatTF32() { return semFloatTF32; }
 const fltSemantics &APFloatBase::Float6E3M2FN() { return semFloat6E3M2FN; }
 const fltSemantics &APFloatBase::Float6E2M3FN() { return semFloat6E2M3FN; }
+const fltSemantics &APFloatBase::Float4E2M1FN() { return semFloat4E2M1FN; }
 const fltSemantics &APFloatBase::x87DoubleExtended() {
   return semX87DoubleExtended;
 }
@@ -3640,6 +3647,11 @@ APInt IEEEFloat::convertFloat6E2M3FNAPFloatToAPInt() const {
   return convertIEEEFloatToAPInt<semFloat6E2M3FN>();
 }
 
+APInt IEEEFloat::convertFloat4E2M1FNAPFloatToAPInt() const {
+  assert(partCount() == 1);
+  return convertIEEEFloatToAPInt<semFloat4E2M1FN>();
+}
+
 // This function creates an APInt that is just a bit map of the floating
 // point constant as it would appear in memory.  It is not a conversion,
 // and treating the result as a normal integer is unlikely to be useful.
@@ -3687,6 +3699,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
   if (semantics == (const llvm::fltSemantics *)&semFloat6E2M3FN)
     return convertFloat6E2M3FNAPFloatToAPInt();
 
+  if (semantics == (const llvm::fltSemantics *)&semFloat4E2M1FN)
+    return convertFloat4E2M1FNAPFloatToAPInt();
+
   assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended &&
          "unknown format!");
   return convertF80LongDoubleAPFloatToAPInt();
@@ -3911,6 +3926,10 @@ void IEEEFloat::initFromFloat6E2M3FNAPInt(const APInt &api) {
   initFromIEEEAPInt<semFloat6E2M3FN>(api);
 }
 
+void IEEEFloat::initFromFloat4E2M1FNAPInt(const APInt &api) {
+  initFromIEEEAPInt<semFloat4E2M1FN>(api);
+}
+
 /// Treat api as containing the bits of a floating point number.
 void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
   assert(api.getBitWidth() == Sem->sizeInBits);
@@ -3944,6 +3963,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
     return initFromFloat6E3M2FNAPInt(api);
   if (Sem == &semFloat6E2M3FN)
     return initFromFloat6E2M3FNAPInt(api);
+  if (Sem == &semFloat4E2M1FN)
+    return initFromFloat4E2M1FNAPInt(api);
 
   llvm_unreachable(nullptr);
 }
diff --git a/llvm/unittests/ADT/APFloatTest.cpp b/llvm/unittests/ADT/APFloatTest.cpp
index 7007d944801a7..f6af4b0e5f651 100644
--- a/llvm/unittests/ADT/APFloatTest.cpp
+++ b/llvm/unittests/ADT/APFloatTest.cpp
@@ -1828,6 +1828,7 @@ TEST(APFloatTest, getLargest) {
   EXPECT_EQ(28, APFloat::getLargest(APFloat::Float6E3M2FN()).convertToDouble());
   EXPECT_EQ(7.5,
             APFloat::getLargest(APFloat::Float6E2M3FN()).convertToDouble());
+  EXPECT_EQ(6, APFloat::getLargest(APFloat::Float4E2M1FN()).convertToDouble());
 }
 
 TEST(APFloatTest, getSmallest) {
@@ -1900,6 +1901,13 @@ TEST(APFloatTest, getSmallest) {
   EXPECT_TRUE(test.isFiniteNonZero());
   EXPECT_TRUE(test.isDenormal());
   EXPECT_TRUE(test.bitwiseIsEqual(expected));
+
+  test = APFloat::getSmallest(APFloat::Float4E2M1FN(), false);
+  expected = APFloat(APFloat::Float4E2M1FN(), "0x0.8p0");
+  EXPECT_FALSE(test.isNegative());
+  EXPECT_TRUE(test.isFiniteNonZero());
+  EXPECT_TRUE(test.isDenormal());
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
 }
 
 TEST(APFloatTest, getSmallestNormalized) {
@@ -1984,6 +1992,14 @@ TEST(APFloatTest, getSmallestNormalized) {
   EXPECT_TRUE(test.isSmallestNormalized());
   test = APFloat::getSmallestNormalized(APFloat::Float6E3M2FN(), false);
   expected = APFloat(APFloat::Float6E3M2FN(), "0x1p-2");
+  EXPECT_FALSE(test.isNegative());
+  EXPECT_TRUE(test.isFiniteNonZero());
+  EXPECT_FALSE(test.isDenormal());
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+  EXPECT_TRUE(test.isSmallestNormalized());
+
+  test = APFloat::getSmallestNormalized(APFloat::Float4E2M1FN(), false);
+  expected = APFloat(APFloat::Float4E2M1FN(), "0x1p0");
   EXPECT_FALSE(test.isNegative());
   EXPECT_TRUE(test.isFiniteNonZero());
   EXPECT_FALSE(test.isDenormal());
@@ -2034,7 +2050,9 @@ TEST(APFloatTest, getZero) {
       {&APFloat::Float6E3M2FN(), false, true, {0, 0}, 1},
       {&APFloat::Float6E3M2FN(), true, true, {0x20ULL, 0}, 1},
       {&APFloat::Float6E2M3FN(), false, true, {0, 0}, 1},
-      {&APFloat::Float6E2M3FN(), true, true, {0x20ULL, 0}, 1}};
+      {&APFloat::Float6E2M3FN(), true, true, {0x20ULL, 0}, 1},
+      {&APFloat::Float4E2M1FN(), false, true, {0, 0}, 1},
+      {&APFloat::Float4E2M1FN(), true, true, {0x8ULL, 0}, 1}};
   const unsigned NumGetZeroTests = std::size(GetZeroTest);
   for (unsigned i = 0; i < NumGetZeroTests; ++i) {
     APFloat test = APFloat::getZero(*GetZeroTest[i].semantics,
@@ -5283,6 +5301,89 @@ TEST(APFloatTest, Float6ExhaustivePair) {
   }
 }
 
+TEST(APFloatTest, Float4ExhaustivePair) {
+  // Exhaustively check each pair of 4-bit values against IEEEhalf arithmetic
+  for (APFloat::Semantics Sem : {APFloat::S_Float4E2M1FN}) {
+    const llvm::fltSemantics &S = APFloat::EnumToSemantics(Sem);
+    for (int i = 0; i < 16; i++) {
+      for (int j = 0; j < 16; j++) {
+        SCOPED_TRACE("sem=" + std::to_string(Sem) + ",i=" + std::to_string(i) +
+                     ",j=" + std::to_string(j));
+        APFloat x(S, APInt(4, i));
+        APFloat y(S, APInt(4, j));
+
+        bool losesInfo;
+        APFloat x16 = x;
+        x16.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven,
+                    &losesInfo);
+        EXPECT_FALSE(losesInfo);
+        APFloat y16 = y;
+        y16.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven,
+                    &losesInfo);
+        EXPECT_FALSE(losesInfo);
+
+        // Add: the 4-bit result must equal the half result rounded back to S
+        APFloat z = x;
+        z.add(y, APFloat::rmNearestTiesToEven);
+        APFloat z16 = x16;
+        z16.add(y16, APFloat::rmNearestTiesToEven);
+        z16.convert(S, APFloat::rmNearestTiesToEven, &losesInfo);
+        EXPECT_TRUE(z.bitwiseIsEqual(z16))
+            << "sem=" << Sem << ", i=" << i << ", j=" << j;
+
+        // Subtract
+        z = x;
+        z.subtract(y, APFloat::rmNearestTiesToEven);
+        z16 = x16;
+        z16.subtract(y16, APFloat::rmNearestTiesToEven);
+        z16.convert(S, APFloat::rmNearestTiesToEven, &losesInfo);
+        EXPECT_TRUE(z.bitwiseIsEqual(z16))
+            << "sem=" << Sem << ", i=" << i << ", j=" << j;
+
+        // Multiply
+        z = x;
+        z.multiply(y, APFloat::rmNearestTiesToEven);
+        z16 = x16;
+        z16.multiply(y16, APFloat::rmNearestTiesToEven);
+        z16.convert(S, APFloat::rmNearestTiesToEven, &losesInfo);
+        EXPECT_TRUE(z.bitwiseIsEqual(z16))
+            << "sem=" << Sem << ", i=" << i << ", j=" << j;
+
+        // Skip divide/mod/remainder by 0 (patterns 0 and 8 are +0 and -0)
+        if (j == 0 || j == 8)
+          continue;
+
+        // Divide
+        z = x;
+        z.divide(y, APFloat::rmNearestTiesToEven);
+        z16 = x16;
+        z16.divide(y16, APFloat::rmNearestTiesToEven);
+        z16.convert(S, APFloat::rmNearestTiesToEven, &losesInfo);
+        EXPECT_TRUE(z.bitwiseIsEqual(z16))
+            << "sem=" << Sem << ", i=" << i << ", j=" << j;
+
+        // Mod
+        z = x;
+        z.mod(y);
+        z16 = x16;
+        z16.mod(y16);
+        z16.convert(S, APFloat::rmNearestTiesToEven, &losesInfo);
+        EXPECT_TRUE(z.bitwiseIsEqual(z16))
+            << "sem=" << Sem << ", i=" << i << ", j=" << j;
+
+        // Remainder
+        z = x;
+        z.remainder(y);
+        z16 = x16;
+        z16.remainder(y16);
+        z16.convert(S, APFloat::rmNearestTiesToEven, &losesInfo);
+        EXPECT_TRUE(z.bitwiseIsEqual(z16))
+            << "sem=" << Sem << ", i=" << i << ", j=" << j;
+      }
+    }
+  }
+}
+
 TEST(APFloatTest, ConvertE4M3FNToE5M2) {
   bool losesInfo;
   APFloat test(APFloat::Float8E4M3FN(), "1.0");
@@ -6743,7 +6844,7 @@ TEST(APFloatTest, getExactLog2) {
     EXPECT_EQ(INT_MIN, APFloat(Semantics, "3.0").getExactLog2Abs());
     EXPECT_EQ(INT_MIN, APFloat(Semantics, "-3.0").getExactLog2Abs());
 
-    if (I == APFloat::S_Float6E2M3FN) {
+    if (I == APFloat::S_Float6E2M3FN || I == APFloat::S_Float4E2M1FN) {
       EXPECT_EQ(2, APFloat(Semantics, "4.0").getExactLog2());
       EXPECT_EQ(INT_MIN, APFloat(Semantics, "-4.0").getExactLog2());
       EXPECT_EQ(2, APFloat(Semantics, "4.0").getExactLog2Abs());
@@ -6831,6 +6932,25 @@ TEST(APFloatTest, Float6E2M3FNFromString) {
   EXPECT_TRUE(APFloat(APFloat::Float6E2M3FN(), "-0").isNegZero());
 }
 
+TEST(APFloatTest, Float4E2M1FNFromString) {
+  // Exactly representable (6 is also the largest value of the format)
+  EXPECT_EQ(6, APFloat(APFloat::Float4E2M1FN(), "6").convertToDouble());
+  // Out-of-range input saturates to the maximum value (no Inf encoding)
+  EXPECT_EQ(6, APFloat(APFloat::Float4E2M1FN(), "32").convertToDouble());
+
+#ifdef GTEST_HAS_DEATH_TEST
+#ifndef NDEBUG
+  EXPECT_DEATH(APFloat(APFloat::Float4E2M1FN(), "inf"),
+               "This floating point format does not support Inf");
+  EXPECT_DEATH(APFloat(APFloat::Float4E2M1FN(), "nan"),
+               "This floating point format does not support NaN");
+#endif
+#endif
+
+  EXPECT_TRUE(APFloat(APFloat::Float4E2M1FN(), "0").isPosZero());
+  EXPECT_TRUE(APFloat(APFloat::Float4E2M1FN(), "-0").isNegZero());
+}
+
 TEST(APFloatTest, ConvertE3M2FToE2M3F) {
   bool losesInfo;
   APFloat test(APFloat::Float6E3M2FN(), "1.0");
@@ -6848,7 +6968,6 @@ TEST(APFloatTest, ConvertE3M2FToE2M3F) {
   EXPECT_EQ(status, APFloat::opOK);
 
   // Test overflow
-  losesInfo = false;
   test = APFloat(APFloat::Float6E3M2FN(), "28");
   status = test.convert(APFloat::Float6E2M3FN(), APFloat::rmNearestTiesToEven,
                         &losesInfo);
@@ -6865,7 +6984,6 @@ TEST(APFloatTest, ConvertE3M2FToE2M3F) {
   EXPECT_EQ(status, APFloat::opUnderflow | APFloat::opInexact);
 
   // Testing inexact rounding to denormal number
-  losesInfo = false;
   test = APFloat(APFloat::Float6E3M2FN(), "0.1875");
   status = test.convert(APFloat::Float6E2M3FN(), APFloat::rmNearestTiesToEven,
                         &losesInfo);
@@ -6898,7 +7016,6 @@ TEST(APFloatTest, ConvertE2M3FToE3M2F) {
   EXPECT_EQ(status, APFloat::opOK);
 
   // Test inexact rounding
-  losesInfo = false;
   test = APFloat(APFloat::Float6E2M3FN(), "7.5");
   status = test.convert(APFloat::Float6E3M2FN(), APFloat::rmNearestTiesToEven,
                         &losesInfo);
@@ -6907,6 +7024,40 @@ TEST(APFloatTest, ConvertE2M3FToE3M2F) {
   EXPECT_EQ(status, APFloat::opInexact);
 }
 
+TEST(APFloatTest, ConvertDoubleToE2M1F) {
+  bool losesInfo;
+  APFloat test(APFloat::IEEEdouble(), "1.0");
+  APFloat::opStatus status = test.convert(
+      APFloat::Float4E2M1FN(), APFloat::rmNearestTiesToEven, &losesInfo);
+  EXPECT_EQ(1.0, test.convertToDouble());
+  EXPECT_FALSE(losesInfo);
+  EXPECT_EQ(status, APFloat::opOK);
+
+  test = APFloat(APFloat::IEEEdouble(), "0.0");
+  status = test.convert(APFloat::Float4E2M1FN(), APFloat::rmNearestTiesToEven,
+                        &losesInfo);
+  EXPECT_EQ(0.0, test.convertToDouble());
+  EXPECT_FALSE(losesInfo);
+  EXPECT_EQ(status, APFloat::opOK);
+
+  // Test overflow: no Inf encoding, so 8 saturates to the largest value (6)
+  test = APFloat(APFloat::IEEEdouble(), "8");
+  status = test.convert(APFloat::Float4E2M1FN(), APFloat::rmNearestTiesToEven,
+                        &losesInfo);
+  EXPECT_EQ(6, test.convertToDouble());
+  EXPECT_TRUE(losesInfo);
+  EXPECT_EQ(status, APFloat::opInexact);
+
+  // Test underflow: 0.25 is a tie between 0 and 0.5; ties-to-even gives 0
+  test = APFloat(APFloat::IEEEdouble(), "0.25");
+  status = test.convert(APFloat::Float4E2M1FN(), APFloat::rmNearestTiesToEven,
+                        &losesInfo);
+  EXPECT_EQ(0.0, test.convertToDouble());
+  EXPECT_TRUE(losesInfo);
+  EXPECT_FALSE(test.isDenormal());
+  EXPECT_EQ(status, APFloat::opUnderflow | APFloat::opInexact);
+}
+
 TEST(APFloatTest, Float6E3M2FNNext) {
   APFloat test(APFloat::Float6E3M2FN(), APFloat::uninitialized);
   APFloat expected(APFloat::Float6E3M2FN(), APFloat::uninitialized);
@@ -6983,6 +7134,44 @@ TEST(APFloatTest, Float6E2M3FNNext) {
   EXPECT_TRUE(test.bitwiseIsEqual(expected));
 }
 
+TEST(APFloatTest, Float4E2M1FNNext) {
+  APFloat test(APFloat::Float4E2M1FN(), APFloat::uninitialized);
+  APFloat expected(APFloat::Float4E2M1FN(), APFloat::uninitialized);
+
+  // 1. nextUp of the largest value saturates: the bit pattern is unchanged
+  test = APFloat::getLargest(APFloat::Float4E2M1FN());
+  expected = APFloat::getLargest(APFloat::Float4E2M1FN());
+  EXPECT_EQ(test.next(false), APFloat::opOK);
+  EXPECT_FALSE(test.isInfinity());
+  EXPECT_FALSE(test.isZero());
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+
+  // 2. nextUp of the smallest negative denormal is -0
+  test = APFloat::getSmallest(APFloat::Float4E2M1FN(), true);
+  expected = APFloat::getZero(APFloat::Float4E2M1FN(), true);
+  EXPECT_EQ(test.next(false), APFloat::opOK);
+  EXPECT_TRUE(test.isNegZero());
+  EXPECT_FALSE(test.isPosZero());
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+
+  // 3. nextDown of the negated largest value saturates: it is unchanged
+  test = APFloat::getLargest(APFloat::Float4E2M1FN(), true);
+  expected = test;
+  EXPECT_EQ(test.next(true), APFloat::opOK);
+  EXPECT_FALSE(test.isInfinity());
+  EXPECT_FALSE(test.isZero());
+  EXPECT_FALSE(test.isNaN());
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+
+  // 4. nextDown of +0 is the smallest negative denormal
+  test = APFloat::getZero(APFloat::Float4E2M1FN(), false);
+  expected = APFloat::getSmallest(APFloat::Float4E2M1FN(), true);
+  EXPECT_EQ(test.next(true), APFloat::opOK);
+  EXPECT_FALSE(test.isZero());
+  EXPECT_TRUE(test.isDenormal());
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+}
+
 #ifdef GTEST_HAS_DEATH_TEST
 #ifndef NDEBUG
 TEST(APFloatTest, Float6E3M2FNGetInfNaN) {
@@ -6998,6 +7187,13 @@ TEST(APFloatTest, Float6E2M3FNGetInfNaN) {
   EXPECT_DEATH(APFloat::getNaN(APFloat::Float6E2M3FN()),
                "This floating point format does not support NaN");
 }
+
+TEST(APFloatTest, Float4E2M1FNGetInfNaN) {
+  EXPECT_DEATH(APFloat::getInf(APFloat::Float4E2M1FN()),
+               "This floating point format does not support Inf");
+  EXPECT_DEATH(APFloat::getNaN(APFloat::Float4E2M1FN()),
+               "This floating point format does not support NaN");
+}
 #endif
 #endif
 
@@ -7043,6 +7239,27 @@ TEST(APFloatTest, Float6E2M3FNToDouble) {
   EXPECT_EQ(0x0.2p0, SmallestDenorm.convertToDouble());
 }
 
+TEST(APFloatTest, Float4E2M1FNToDouble) {
+  APFloat One(APFloat::Float4E2M1FN(), "1.0");
+  EXPECT_EQ(1.0, One.convertToDouble());
+  APFloat Two(APFloat::Float4E2M1FN(), "2.0");
+  EXPECT_EQ(2.0, Two.convertToDouble());
+  APFloat PosLargest = APFloat::getLargest(APFloat::Float4E2M1FN(), false);
+  EXPECT_EQ(6, PosLargest.convertToDouble());
+  APFloat NegLargest = APFloat::getLargest(APFloat::Float4E2M1FN(), true);
+  EXPECT_EQ(-6, NegLargest.convertToDouble());
+  APFloat PosSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float4E2M1FN(), false);
+  EXPECT_EQ(0x1p0, PosSmallest.convertToDouble());
+  APFloat NegSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float4E2M1FN(), true);
+  EXPECT_EQ(-0x1p0, NegSmallest.convertToDouble());
+
+  APFloat SmallestDenorm = APFloat::getSmallest(APFloat::Float4E2M1FN(), false);
+  EXPECT_TRUE(SmallestDenorm.isDenormal());
+  EXPECT_EQ(0x0.8p0, SmallestDenorm.convertToDouble());
+}
+
 TEST(APFloatTest, Float6E3M2FNToFloat) {
   APFloat PosZero = APFloat::getZero(APFloat::Float6E3M2FN());
   APFloat PosZeroToFloat(PosZero.convertToFloat());
@@ -7100,4 +7317,33 @@ TEST(APFloatTest, Float6E2M3FNToFloat) {
   EXPECT_TRUE(SmallestDenorm.isDenormal());
   EXPECT_EQ(0x0.2p0, SmallestDenorm.convertToFloat());
 }
+
+TEST(APFloatTest, Float4E2M1FNToFloat) {
+  APFloat PosZero = APFloat::getZero(APFloat::Float4E2M1FN());
+  APFloat PosZeroToFloat(PosZero.convertToFloat());
+  EXPECT_TRUE(PosZeroToFloat.isPosZero());
+  APFloat NegZero = APFloat::getZero(APFloat::Float4E2M1FN(), true);
+  APFloat NegZeroToFloat(NegZero.convertToFloat());
+  EXPECT_TRUE(NegZeroToFloat.isNegZero());
+
+  APFloat One(APFloat::Float4E2M1FN(), "1.0");
+  EXPECT_EQ(1.0F, One.convertToFloat());
+  APFloat Two(APFloat::Float4E2M1FN(), "2.0");
+  EXPECT_EQ(2.0F, Two.convertToFloat());
+
+  APFloat PosLargest = APFloat::getLargest(APFloat::Float4E2M1FN(), false);
+  EXPECT_EQ(6, PosLargest.convertToFloat());
+  APFloat NegLargest = APFloat::getLargest(APFloat::Float4E2M1FN(), true);
+  EXPECT_EQ(-6, NegLargest.convertToFloat());
+  APFloat PosSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float4E2M1FN(), false);
+  EXPECT_EQ(0x1p0, PosSmallest.convertToFloat());
+  APFloat NegSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float4E2M1FN(), true);
+  EXPECT_EQ(-0x1p0, NegSmallest.convertToFloat());
+
+  APFloat SmallestDenorm = APFloat::getSmallest(APFloat::Float4E2M1FN(), false);
+  EXPECT_TRUE(SmallestDenorm.isDenormal());
+  EXPECT_EQ(0x0.8p0, SmallestDenorm.convertToFloat());
+}
 } // namespace



More information about the cfe-commits mailing list