[clang] 55c2211 - [APFloat] Add APFloat semantic support for TF32

Mehdi Amini via cfe-commits cfe-commits at lists.llvm.org
Fri Jun 23 01:55:43 PDT 2023


Author: Jeremy Furtek
Date: 2023-06-23T10:54:49+02:00
New Revision: 55c2211a233e11179048cf58778f40e5a62f444a

URL: https://github.com/llvm/llvm-project/commit/55c2211a233e11179048cf58778f40e5a62f444a
DIFF: https://github.com/llvm/llvm-project/commit/55c2211a233e11179048cf58778f40e5a62f444a.diff

LOG: [APFloat] Add APFloat semantic support for TF32

This diff adds APFloat support for a semantic that matches the TF32 data type
used by some accelerators (most notably GPUs from both NVIDIA and AMD).

See https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/ for more information on the TF32 data type.
Some intrinsics that support the TF32 data type were added in D122044 (https://reviews.llvm.org/D122044).

For some discussion on supporting common semantics in `APFloat`, see similar
efforts for 8-bit formats at https://reviews.llvm.org/D146441, as well as
https://discourse.llvm.org/t/rfc-adding-the-amd-graphcore-maybe-others-float8-formats-to-apfloat/67969.

A subsequent diff will extend MLIR to use this data type. (Those changes are
not part of this diff to simplify the review process.)

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D151923

Added: 
    

Modified: 
    clang/lib/AST/MicrosoftMangle.cpp
    llvm/include/llvm/ADT/APFloat.h
    llvm/lib/Support/APFloat.cpp
    llvm/unittests/ADT/APFloatTest.cpp

Removed: 
    


################################################################################
diff  --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp
index 9fede7bbad323..430a57d7b4ec0 100644
--- a/clang/lib/AST/MicrosoftMangle.cpp
+++ b/clang/lib/AST/MicrosoftMangle.cpp
@@ -898,6 +898,7 @@ void MicrosoftCXXNameMangler::mangleFloat(llvm::APFloat Number) {
   case APFloat::S_Float8E5M2FNUZ:
   case APFloat::S_Float8E4M3FNUZ:
   case APFloat::S_Float8E4M3B11FNUZ:
+  case APFloat::S_FloatTF32:
     llvm_unreachable("Tried to mangle unexpected APFloat semantics");
   }
 

diff  --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h
index 1875706362f79..64caa5a765456 100644
--- a/llvm/include/llvm/ADT/APFloat.h
+++ b/llvm/include/llvm/ADT/APFloat.h
@@ -184,6 +184,10 @@ struct APFloatBase {
     // This format's exponent bias is 11, instead of the 7 (2 ** (4 - 1) - 1)
     // that IEEE precedent would imply.
     S_Float8E4M3B11FNUZ,
+    // Floating point number that occupies 32 bits or less of storage, providing
+    // improved range compared to half (16-bit) formats, at (potentially)
+    // greater throughput than single precision (32-bit) formats.
+    S_FloatTF32,
 
     S_x87DoubleExtended,
     S_MaxSemantics = S_x87DoubleExtended,
@@ -203,6 +207,7 @@ struct APFloatBase {
   static const fltSemantics &Float8E4M3FN() LLVM_READNONE;
   static const fltSemantics &Float8E4M3FNUZ() LLVM_READNONE;
   static const fltSemantics &Float8E4M3B11FNUZ() LLVM_READNONE;
+  static const fltSemantics &FloatTF32() LLVM_READNONE;
   static const fltSemantics &x87DoubleExtended() LLVM_READNONE;
 
   /// A Pseudo fltsemantic used to construct APFloats that cannot conflict with
@@ -605,6 +610,7 @@ class IEEEFloat final : public APFloatBase {
   APInt convertFloat8E4M3FNAPFloatToAPInt() const;
   APInt convertFloat8E4M3FNUZAPFloatToAPInt() const;
   APInt convertFloat8E4M3B11FNUZAPFloatToAPInt() const;
+  APInt convertFloatTF32APFloatToAPInt() const;
   void initFromAPInt(const fltSemantics *Sem, const APInt &api);
   template <const fltSemantics &S> void initFromIEEEAPInt(const APInt &api);
   void initFromHalfAPInt(const APInt &api);
@@ -619,6 +625,7 @@ class IEEEFloat final : public APFloatBase {
   void initFromFloat8E4M3FNAPInt(const APInt &api);
   void initFromFloat8E4M3FNUZAPInt(const APInt &api);
   void initFromFloat8E4M3B11FNUZAPInt(const APInt &api);
+  void initFromFloatTF32APInt(const APInt &api);
 
   void assign(const IEEEFloat &);
   void copySignificand(const IEEEFloat &);

diff  --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index c882c08b256e7..4a73739b5282a 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -138,6 +138,7 @@ static constexpr fltSemantics semFloat8E4M3FNUZ = {
     7, -7, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
 static constexpr fltSemantics semFloat8E4M3B11FNUZ = {
     4, -10, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
+static constexpr fltSemantics semFloatTF32 = {127, -126, 11, 19};
 static constexpr fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80};
 static constexpr fltSemantics semBogus = {0, 0, 0, 0};
 
@@ -203,6 +204,8 @@ const llvm::fltSemantics &APFloatBase::EnumToSemantics(Semantics S) {
     return Float8E4M3FNUZ();
   case S_Float8E4M3B11FNUZ:
     return Float8E4M3B11FNUZ();
+  case S_FloatTF32:
+    return FloatTF32();
   case S_x87DoubleExtended:
     return x87DoubleExtended();
   }
@@ -233,6 +236,8 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) {
     return S_Float8E4M3FNUZ;
   else if (&Sem == &llvm::APFloat::Float8E4M3B11FNUZ())
     return S_Float8E4M3B11FNUZ;
+  else if (&Sem == &llvm::APFloat::FloatTF32())
+    return S_FloatTF32;
   else if (&Sem == &llvm::APFloat::x87DoubleExtended())
     return S_x87DoubleExtended;
   else
@@ -254,6 +259,7 @@ const fltSemantics &APFloatBase::Float8E4M3FNUZ() { return semFloat8E4M3FNUZ; }
 const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() {
   return semFloat8E4M3B11FNUZ;
 }
+const fltSemantics &APFloatBase::FloatTF32() { return semFloatTF32; }
 const fltSemantics &APFloatBase::x87DoubleExtended() {
   return semX87DoubleExtended;
 }
@@ -3599,6 +3605,11 @@ APInt IEEEFloat::convertFloat8E4M3B11FNUZAPFloatToAPInt() const {
   return convertIEEEFloatToAPInt<semFloat8E4M3B11FNUZ>();
 }
 
+APInt IEEEFloat::convertFloatTF32APFloatToAPInt() const {
+  assert(partCount() == 1);
+  return convertIEEEFloatToAPInt<semFloatTF32>();
+}
+
 // This function creates an APInt that is just a bit map of the floating
 // point constant as it would appear in memory.  It is not a conversion,
 // and treating the result as a normal integer is unlikely to be useful.
@@ -3637,6 +3648,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
   if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3B11FNUZ)
     return convertFloat8E4M3B11FNUZAPFloatToAPInt();
 
+  if (semantics == (const llvm::fltSemantics *)&semFloatTF32)
+    return convertFloatTF32APFloatToAPInt();
+
   assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended &&
          "unknown format!");
   return convertF80LongDoubleAPFloatToAPInt();
@@ -3840,6 +3854,10 @@ void IEEEFloat::initFromFloat8E4M3B11FNUZAPInt(const APInt &api) {
   initFromIEEEAPInt<semFloat8E4M3B11FNUZ>(api);
 }
 
+void IEEEFloat::initFromFloatTF32APInt(const APInt &api) {
+  initFromIEEEAPInt<semFloatTF32>(api);
+}
+
 /// Treat api as containing the bits of a floating point number.
 void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
   assert(api.getBitWidth() == Sem->sizeInBits);
@@ -3867,6 +3885,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
     return initFromFloat8E4M3FNUZAPInt(api);
   if (Sem == &semFloat8E4M3B11FNUZ)
     return initFromFloat8E4M3B11FNUZAPInt(api);
+  if (Sem == &semFloatTF32)
+    return initFromFloatTF32APInt(api);
 
   llvm_unreachable(nullptr);
 }

diff  --git a/llvm/unittests/ADT/APFloatTest.cpp b/llvm/unittests/ADT/APFloatTest.cpp
index 2d79d34c3104a..c8a5c67feaf80 100644
--- a/llvm/unittests/ADT/APFloatTest.cpp
+++ b/llvm/unittests/ADT/APFloatTest.cpp
@@ -682,6 +682,26 @@ TEST(APFloatTest, Denormal) {
     EXPECT_TRUE(T.isDenormal());
     EXPECT_EQ(fcPosSubnormal, T.classify());
   }
+
+  // Test TF32
+  {
+    const char *MinNormalStr = "1.17549435082228750797e-38";
+    EXPECT_FALSE(APFloat(APFloat::FloatTF32(), MinNormalStr).isDenormal());
+    EXPECT_FALSE(APFloat(APFloat::FloatTF32(), 0).isDenormal());
+
+    APFloat Val2(APFloat::FloatTF32(), 2);
+    APFloat T(APFloat::FloatTF32(), MinNormalStr);
+    T.divide(Val2, rdmd);
+    EXPECT_TRUE(T.isDenormal());
+    EXPECT_EQ(fcPosSubnormal, T.classify());
+
+    const char *NegMinNormalStr = "-1.17549435082228750797e-38";
+    EXPECT_FALSE(APFloat(APFloat::FloatTF32(), NegMinNormalStr).isDenormal());
+    APFloat NegT(APFloat::FloatTF32(), NegMinNormalStr);
+    NegT.divide(Val2, rdmd);
+    EXPECT_TRUE(NegT.isDenormal());
+    EXPECT_EQ(fcNegSubnormal, NegT.classify());
+  }
 }
 
 TEST(APFloatTest, IsSmallestNormalized) {
@@ -1350,6 +1370,16 @@ TEST(APFloatTest, makeNaN) {
     {               0x80ULL, APFloat::Float8E4M3B11FNUZ(), false, true,         0xaaULL },
     {               0x80ULL, APFloat::Float8E4M3B11FNUZ(), true, false,         0xaaULL },
     {               0x80ULL, APFloat::Float8E4M3B11FNUZ(), true, true,          0xaaULL },
+    {            0x3fe00ULL, APFloat::FloatTF32(), false, false,          0x00000000ULL },
+    {            0x7fe00ULL, APFloat::FloatTF32(), false,  true,          0x00000000ULL },
+    {            0x3feaaULL, APFloat::FloatTF32(), false, false,                0xaaULL },
+    {            0x3ffaaULL, APFloat::FloatTF32(), false, false,               0xdaaULL },
+    {            0x3ffaaULL, APFloat::FloatTF32(), false, false,              0xfdaaULL },
+    {            0x3fd00ULL, APFloat::FloatTF32(),  true, false,          0x00000000ULL },
+    {            0x7fd00ULL, APFloat::FloatTF32(),  true,  true,          0x00000000ULL },
+    {            0x3fcaaULL, APFloat::FloatTF32(),  true, false,                0xaaULL },
+    {            0x3fdaaULL, APFloat::FloatTF32(),  true, false,               0xfaaULL },
+    {            0x3fdaaULL, APFloat::FloatTF32(),  true, false,               0x1aaULL },
       // clang-format on
   };
 
@@ -1780,6 +1810,8 @@ TEST(APFloatTest, getLargest) {
             APFloat::getLargest(APFloat::Float8E5M2FNUZ()).convertToDouble());
   EXPECT_EQ(
       30, APFloat::getLargest(APFloat::Float8E4M3B11FNUZ()).convertToDouble());
+  EXPECT_EQ(3.40116213421e+38f,
+            APFloat::getLargest(APFloat::FloatTF32()).convertToFloat());
 }
 
 TEST(APFloatTest, getSmallest) {
@@ -1831,6 +1863,13 @@ TEST(APFloatTest, getSmallest) {
   EXPECT_TRUE(test.isFiniteNonZero());
   EXPECT_TRUE(test.isDenormal());
   EXPECT_TRUE(test.bitwiseIsEqual(expected));
+
+  test = APFloat::getSmallest(APFloat::FloatTF32(), true);
+  expected = APFloat(APFloat::FloatTF32(), "-0x0.004p-126");
+  EXPECT_TRUE(test.isNegative());
+  EXPECT_TRUE(test.isFiniteNonZero());
+  EXPECT_TRUE(test.isDenormal());
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
 }
 
 TEST(APFloatTest, getSmallestNormalized) {
@@ -1905,6 +1944,14 @@ TEST(APFloatTest, getSmallestNormalized) {
   EXPECT_FALSE(test.isDenormal());
   EXPECT_TRUE(test.bitwiseIsEqual(expected));
   EXPECT_TRUE(test.isSmallestNormalized());
+
+  test = APFloat::getSmallestNormalized(APFloat::FloatTF32(), false);
+  expected = APFloat(APFloat::FloatTF32(), "0x1p-126");
+  EXPECT_FALSE(test.isNegative());
+  EXPECT_TRUE(test.isFiniteNonZero());
+  EXPECT_FALSE(test.isDenormal());
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+  EXPECT_TRUE(test.isSmallestNormalized());
 }
 
 TEST(APFloatTest, getZero) {
@@ -1936,7 +1983,9 @@ TEST(APFloatTest, getZero) {
       {&APFloat::Float8E4M3FNUZ(), false, false, {0, 0}, 1},
       {&APFloat::Float8E4M3FNUZ(), true, false, {0, 0}, 1},
       {&APFloat::Float8E4M3B11FNUZ(), false, false, {0, 0}, 1},
-      {&APFloat::Float8E4M3B11FNUZ(), true, false, {0, 0}, 1}};
+      {&APFloat::Float8E4M3B11FNUZ(), true, false, {0, 0}, 1},
+      {&APFloat::FloatTF32(), false, true, {0, 0}, 1},
+      {&APFloat::FloatTF32(), true, true, {0x40000ULL, 0}, 1}};
   const unsigned NumGetZeroTests = std::size(GetZeroTest);
   for (unsigned i = 0; i < NumGetZeroTests; ++i) {
     APFloat test = APFloat::getZero(*GetZeroTest[i].semantics,
@@ -6229,6 +6278,34 @@ TEST(APFloatTest, Float8E4M3FNUZToDouble) {
   EXPECT_TRUE(std::isnan(QNaN.convertToDouble()));
 }
 
+TEST(APFloatTest, FloatTF32ToDouble) {
+  APFloat One(APFloat::FloatTF32(), "1.0");
+  EXPECT_EQ(1.0, One.convertToDouble());
+  APFloat PosLargest = APFloat::getLargest(APFloat::FloatTF32(), false);
+  EXPECT_EQ(3.401162134214653489792616e+38, PosLargest.convertToDouble());
+  APFloat NegLargest = APFloat::getLargest(APFloat::FloatTF32(), true);
+  EXPECT_EQ(-3.401162134214653489792616e+38, NegLargest.convertToDouble());
+  APFloat PosSmallest =
+      APFloat::getSmallestNormalized(APFloat::FloatTF32(), false);
+  EXPECT_EQ(1.1754943508222875079687e-38, PosSmallest.convertToDouble());
+  APFloat NegSmallest =
+      APFloat::getSmallestNormalized(APFloat::FloatTF32(), true);
+  EXPECT_EQ(-1.1754943508222875079687e-38, NegSmallest.convertToDouble());
+
+  APFloat SmallestDenorm = APFloat::getSmallest(APFloat::FloatTF32(), false);
+  EXPECT_EQ(1.1479437019748901445007e-41, SmallestDenorm.convertToDouble());
+  APFloat LargestDenorm(APFloat::FloatTF32(), "0x1.FF8p-127");
+  EXPECT_EQ(/*0x1.FF8p-127*/ 1.1743464071203126178242e-38,
+            LargestDenorm.convertToDouble());
+
+  APFloat PosInf = APFloat::getInf(APFloat::FloatTF32());
+  EXPECT_EQ(std::numeric_limits<double>::infinity(), PosInf.convertToDouble());
+  APFloat NegInf = APFloat::getInf(APFloat::FloatTF32(), true);
+  EXPECT_EQ(-std::numeric_limits<double>::infinity(), NegInf.convertToDouble());
+  APFloat QNaN = APFloat::getQNaN(APFloat::FloatTF32());
+  EXPECT_TRUE(std::isnan(QNaN.convertToDouble()));
+}
+
 TEST(APFloatTest, Float8E5M2FNUZToFloat) {
   APFloat PosZero = APFloat::getZero(APFloat::Float8E5M2FNUZ());
   APFloat PosZeroToFloat(PosZero.convertToFloat());
@@ -6473,4 +6550,40 @@ TEST(APFloatTest, Float8E4M3FNToFloat) {
   EXPECT_TRUE(std::isnan(QNaN.convertToFloat()));
 }
 
+TEST(APFloatTest, FloatTF32ToFloat) {
+  APFloat PosZero = APFloat::getZero(APFloat::FloatTF32());
+  APFloat PosZeroToFloat(PosZero.convertToFloat());
+  EXPECT_TRUE(PosZeroToFloat.isPosZero());
+  APFloat NegZero = APFloat::getZero(APFloat::FloatTF32(), true);
+  APFloat NegZeroToFloat(NegZero.convertToFloat());
+  EXPECT_TRUE(NegZeroToFloat.isNegZero());
+
+  APFloat One(APFloat::FloatTF32(), "1.0");
+  EXPECT_EQ(1.0F, One.convertToFloat());
+  APFloat Two(APFloat::FloatTF32(), "2.0");
+  EXPECT_EQ(2.0F, Two.convertToFloat());
+
+  APFloat PosLargest = APFloat::getLargest(APFloat::FloatTF32(), false);
+  EXPECT_EQ(3.40116213421e+38F, PosLargest.convertToFloat());
+
+  APFloat NegLargest = APFloat::getLargest(APFloat::FloatTF32(), true);
+  EXPECT_EQ(-3.40116213421e+38F, NegLargest.convertToFloat());
+
+  APFloat PosSmallest =
+      APFloat::getSmallestNormalized(APFloat::FloatTF32(), false);
+  EXPECT_EQ(/*0x1.p-126*/ 1.1754943508222875e-38F,
+            PosSmallest.convertToFloat());
+  APFloat NegSmallest =
+      APFloat::getSmallestNormalized(APFloat::FloatTF32(), true);
+  EXPECT_EQ(/*-0x1.p-126*/ -1.1754943508222875e-38F,
+            NegSmallest.convertToFloat());
+
+  APFloat SmallestDenorm = APFloat::getSmallest(APFloat::FloatTF32(), false);
+  EXPECT_TRUE(SmallestDenorm.isDenormal());
+  EXPECT_EQ(0x0.004p-126, SmallestDenorm.convertToFloat());
+
+  APFloat QNaN = APFloat::getQNaN(APFloat::FloatTF32());
+  EXPECT_TRUE(std::isnan(QNaN.convertToFloat()));
+}
+
 } // namespace


        


More information about the cfe-commits mailing list