[llvm] 55c2211 - [APFloat] Add APFloat semantic support for TF32
Mehdi Amini via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 23 01:55:44 PDT 2023
Author: Jeremy Furtek
Date: 2023-06-23T10:54:49+02:00
New Revision: 55c2211a233e11179048cf58778f40e5a62f444a
URL: https://github.com/llvm/llvm-project/commit/55c2211a233e11179048cf58778f40e5a62f444a
DIFF: https://github.com/llvm/llvm-project/commit/55c2211a233e11179048cf58778f40e5a62f444a.diff
LOG: [APFloat] Add APFloat semantic support for TF32
This diff adds APFloat support for a semantic that matches the TF32 data type
used by some accelerators (most notably GPUs from both NVIDIA and AMD).
For more information on the TF32 data type, see NVIDIA's blog post: https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/
Some intrinsics that support the TF32 data type were added in https://reviews.llvm.org/D122044.
For some discussion on supporting common semantics in `APFloat`, see similar
efforts for 8-bit formats at https://reviews.llvm.org/D146441, as well as
https://discourse.llvm.org/t/rfc-adding-the-amd-graphcore-maybe-others-float8-formats-to-apfloat/67969.
A subsequent diff will extend MLIR to use this data type. (Those changes are
not part of this diff to simplify the review process.)
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D151923
Added:
Modified:
clang/lib/AST/MicrosoftMangle.cpp
llvm/include/llvm/ADT/APFloat.h
llvm/lib/Support/APFloat.cpp
llvm/unittests/ADT/APFloatTest.cpp
Removed:
################################################################################
diff --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp
index 9fede7bbad323..430a57d7b4ec0 100644
--- a/clang/lib/AST/MicrosoftMangle.cpp
+++ b/clang/lib/AST/MicrosoftMangle.cpp
@@ -898,6 +898,7 @@ void MicrosoftCXXNameMangler::mangleFloat(llvm::APFloat Number) {
case APFloat::S_Float8E5M2FNUZ:
case APFloat::S_Float8E4M3FNUZ:
case APFloat::S_Float8E4M3B11FNUZ:
+ case APFloat::S_FloatTF32:
llvm_unreachable("Tried to mangle unexpected APFloat semantics");
}
diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h
index 1875706362f79..64caa5a765456 100644
--- a/llvm/include/llvm/ADT/APFloat.h
+++ b/llvm/include/llvm/ADT/APFloat.h
@@ -184,6 +184,10 @@ struct APFloatBase {
// This format's exponent bias is 11, instead of the 7 (2 ** (4 - 1) - 1)
// that IEEE precedent would imply.
S_Float8E4M3B11FNUZ,
+ // Floating point number that occupies 32 bits or less of storage, providing
+ // improved range compared to half (16-bit) formats, at (potentially)
+ // greater throughput than single precision (32-bit) formats.
+ S_FloatTF32,
S_x87DoubleExtended,
S_MaxSemantics = S_x87DoubleExtended,
@@ -203,6 +207,7 @@ struct APFloatBase {
static const fltSemantics &Float8E4M3FN() LLVM_READNONE;
static const fltSemantics &Float8E4M3FNUZ() LLVM_READNONE;
static const fltSemantics &Float8E4M3B11FNUZ() LLVM_READNONE;
+ static const fltSemantics &FloatTF32() LLVM_READNONE;
static const fltSemantics &x87DoubleExtended() LLVM_READNONE;
/// A Pseudo fltsemantic used to construct APFloats that cannot conflict with
@@ -605,6 +610,7 @@ class IEEEFloat final : public APFloatBase {
APInt convertFloat8E4M3FNAPFloatToAPInt() const;
APInt convertFloat8E4M3FNUZAPFloatToAPInt() const;
APInt convertFloat8E4M3B11FNUZAPFloatToAPInt() const;
+ APInt convertFloatTF32APFloatToAPInt() const;
void initFromAPInt(const fltSemantics *Sem, const APInt &api);
template <const fltSemantics &S> void initFromIEEEAPInt(const APInt &api);
void initFromHalfAPInt(const APInt &api);
@@ -619,6 +625,7 @@ class IEEEFloat final : public APFloatBase {
void initFromFloat8E4M3FNAPInt(const APInt &api);
void initFromFloat8E4M3FNUZAPInt(const APInt &api);
void initFromFloat8E4M3B11FNUZAPInt(const APInt &api);
+ void initFromFloatTF32APInt(const APInt &api);
void assign(const IEEEFloat &);
void copySignificand(const IEEEFloat &);
diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index c882c08b256e7..4a73739b5282a 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -138,6 +138,7 @@ static constexpr fltSemantics semFloat8E4M3FNUZ = {
7, -7, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
static constexpr fltSemantics semFloat8E4M3B11FNUZ = {
4, -10, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
+static constexpr fltSemantics semFloatTF32 = {127, -126, 11, 19};
static constexpr fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80};
static constexpr fltSemantics semBogus = {0, 0, 0, 0};
@@ -203,6 +204,8 @@ const llvm::fltSemantics &APFloatBase::EnumToSemantics(Semantics S) {
return Float8E4M3FNUZ();
case S_Float8E4M3B11FNUZ:
return Float8E4M3B11FNUZ();
+ case S_FloatTF32:
+ return FloatTF32();
case S_x87DoubleExtended:
return x87DoubleExtended();
}
@@ -233,6 +236,8 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) {
return S_Float8E4M3FNUZ;
else if (&Sem == &llvm::APFloat::Float8E4M3B11FNUZ())
return S_Float8E4M3B11FNUZ;
+ else if (&Sem == &llvm::APFloat::FloatTF32())
+ return S_FloatTF32;
else if (&Sem == &llvm::APFloat::x87DoubleExtended())
return S_x87DoubleExtended;
else
@@ -254,6 +259,7 @@ const fltSemantics &APFloatBase::Float8E4M3FNUZ() { return semFloat8E4M3FNUZ; }
const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() {
return semFloat8E4M3B11FNUZ;
}
+const fltSemantics &APFloatBase::FloatTF32() { return semFloatTF32; }
const fltSemantics &APFloatBase::x87DoubleExtended() {
return semX87DoubleExtended;
}
@@ -3599,6 +3605,11 @@ APInt IEEEFloat::convertFloat8E4M3B11FNUZAPFloatToAPInt() const {
return convertIEEEFloatToAPInt<semFloat8E4M3B11FNUZ>();
}
+APInt IEEEFloat::convertFloatTF32APFloatToAPInt() const {
+ assert(partCount() == 1);
+ return convertIEEEFloatToAPInt<semFloatTF32>();
+}
+
// This function creates an APInt that is just a bit map of the floating
// point constant as it would appear in memory. It is not a conversion,
// and treating the result as a normal integer is unlikely to be useful.
@@ -3637,6 +3648,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3B11FNUZ)
return convertFloat8E4M3B11FNUZAPFloatToAPInt();
+ if (semantics == (const llvm::fltSemantics *)&semFloatTF32)
+ return convertFloatTF32APFloatToAPInt();
+
assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended &&
"unknown format!");
return convertF80LongDoubleAPFloatToAPInt();
@@ -3840,6 +3854,10 @@ void IEEEFloat::initFromFloat8E4M3B11FNUZAPInt(const APInt &api) {
initFromIEEEAPInt<semFloat8E4M3B11FNUZ>(api);
}
+void IEEEFloat::initFromFloatTF32APInt(const APInt &api) {
+ initFromIEEEAPInt<semFloatTF32>(api);
+}
+
/// Treat api as containing the bits of a floating point number.
void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
assert(api.getBitWidth() == Sem->sizeInBits);
@@ -3867,6 +3885,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
return initFromFloat8E4M3FNUZAPInt(api);
if (Sem == &semFloat8E4M3B11FNUZ)
return initFromFloat8E4M3B11FNUZAPInt(api);
+ if (Sem == &semFloatTF32)
+ return initFromFloatTF32APInt(api);
llvm_unreachable(nullptr);
}
diff --git a/llvm/unittests/ADT/APFloatTest.cpp b/llvm/unittests/ADT/APFloatTest.cpp
index 2d79d34c3104a..c8a5c67feaf80 100644
--- a/llvm/unittests/ADT/APFloatTest.cpp
+++ b/llvm/unittests/ADT/APFloatTest.cpp
@@ -682,6 +682,26 @@ TEST(APFloatTest, Denormal) {
EXPECT_TRUE(T.isDenormal());
EXPECT_EQ(fcPosSubnormal, T.classify());
}
+
+ // Test TF32
+ {
+ const char *MinNormalStr = "1.17549435082228750797e-38";
+ EXPECT_FALSE(APFloat(APFloat::FloatTF32(), MinNormalStr).isDenormal());
+ EXPECT_FALSE(APFloat(APFloat::FloatTF32(), 0).isDenormal());
+
+ APFloat Val2(APFloat::FloatTF32(), 2);
+ APFloat T(APFloat::FloatTF32(), MinNormalStr);
+ T.divide(Val2, rdmd);
+ EXPECT_TRUE(T.isDenormal());
+ EXPECT_EQ(fcPosSubnormal, T.classify());
+
+ const char *NegMinNormalStr = "-1.17549435082228750797e-38";
+ EXPECT_FALSE(APFloat(APFloat::FloatTF32(), NegMinNormalStr).isDenormal());
+ APFloat NegT(APFloat::FloatTF32(), NegMinNormalStr);
+ NegT.divide(Val2, rdmd);
+ EXPECT_TRUE(NegT.isDenormal());
+ EXPECT_EQ(fcNegSubnormal, NegT.classify());
+ }
}
TEST(APFloatTest, IsSmallestNormalized) {
@@ -1350,6 +1370,16 @@ TEST(APFloatTest, makeNaN) {
{ 0x80ULL, APFloat::Float8E4M3B11FNUZ(), false, true, 0xaaULL },
{ 0x80ULL, APFloat::Float8E4M3B11FNUZ(), true, false, 0xaaULL },
{ 0x80ULL, APFloat::Float8E4M3B11FNUZ(), true, true, 0xaaULL },
+ { 0x3fe00ULL, APFloat::FloatTF32(), false, false, 0x00000000ULL },
+ { 0x7fe00ULL, APFloat::FloatTF32(), false, true, 0x00000000ULL },
+ { 0x3feaaULL, APFloat::FloatTF32(), false, false, 0xaaULL },
+ { 0x3ffaaULL, APFloat::FloatTF32(), false, false, 0xdaaULL },
+ { 0x3ffaaULL, APFloat::FloatTF32(), false, false, 0xfdaaULL },
+ { 0x3fd00ULL, APFloat::FloatTF32(), true, false, 0x00000000ULL },
+ { 0x7fd00ULL, APFloat::FloatTF32(), true, true, 0x00000000ULL },
+ { 0x3fcaaULL, APFloat::FloatTF32(), true, false, 0xaaULL },
+ { 0x3fdaaULL, APFloat::FloatTF32(), true, false, 0xfaaULL },
+ { 0x3fdaaULL, APFloat::FloatTF32(), true, false, 0x1aaULL },
// clang-format on
};
@@ -1780,6 +1810,8 @@ TEST(APFloatTest, getLargest) {
APFloat::getLargest(APFloat::Float8E5M2FNUZ()).convertToDouble());
EXPECT_EQ(
30, APFloat::getLargest(APFloat::Float8E4M3B11FNUZ()).convertToDouble());
+ EXPECT_EQ(3.40116213421e+38f,
+ APFloat::getLargest(APFloat::FloatTF32()).convertToFloat());
}
TEST(APFloatTest, getSmallest) {
@@ -1831,6 +1863,13 @@ TEST(APFloatTest, getSmallest) {
EXPECT_TRUE(test.isFiniteNonZero());
EXPECT_TRUE(test.isDenormal());
EXPECT_TRUE(test.bitwiseIsEqual(expected));
+
+ test = APFloat::getSmallest(APFloat::FloatTF32(), true);
+ expected = APFloat(APFloat::FloatTF32(), "-0x0.004p-126");
+ EXPECT_TRUE(test.isNegative());
+ EXPECT_TRUE(test.isFiniteNonZero());
+ EXPECT_TRUE(test.isDenormal());
+ EXPECT_TRUE(test.bitwiseIsEqual(expected));
}
TEST(APFloatTest, getSmallestNormalized) {
@@ -1905,6 +1944,14 @@ TEST(APFloatTest, getSmallestNormalized) {
EXPECT_FALSE(test.isDenormal());
EXPECT_TRUE(test.bitwiseIsEqual(expected));
EXPECT_TRUE(test.isSmallestNormalized());
+
+ test = APFloat::getSmallestNormalized(APFloat::FloatTF32(), false);
+ expected = APFloat(APFloat::FloatTF32(), "0x1p-126");
+ EXPECT_FALSE(test.isNegative());
+ EXPECT_TRUE(test.isFiniteNonZero());
+ EXPECT_FALSE(test.isDenormal());
+ EXPECT_TRUE(test.bitwiseIsEqual(expected));
+ EXPECT_TRUE(test.isSmallestNormalized());
}
TEST(APFloatTest, getZero) {
@@ -1936,7 +1983,9 @@ TEST(APFloatTest, getZero) {
{&APFloat::Float8E4M3FNUZ(), false, false, {0, 0}, 1},
{&APFloat::Float8E4M3FNUZ(), true, false, {0, 0}, 1},
{&APFloat::Float8E4M3B11FNUZ(), false, false, {0, 0}, 1},
- {&APFloat::Float8E4M3B11FNUZ(), true, false, {0, 0}, 1}};
+ {&APFloat::Float8E4M3B11FNUZ(), true, false, {0, 0}, 1},
+ {&APFloat::FloatTF32(), false, true, {0, 0}, 1},
+ {&APFloat::FloatTF32(), true, true, {0x40000ULL, 0}, 1}};
const unsigned NumGetZeroTests = std::size(GetZeroTest);
for (unsigned i = 0; i < NumGetZeroTests; ++i) {
APFloat test = APFloat::getZero(*GetZeroTest[i].semantics,
@@ -6229,6 +6278,34 @@ TEST(APFloatTest, Float8E4M3FNUZToDouble) {
EXPECT_TRUE(std::isnan(QNaN.convertToDouble()));
}
+TEST(APFloatTest, FloatTF32ToDouble) {
+ APFloat One(APFloat::FloatTF32(), "1.0");
+ EXPECT_EQ(1.0, One.convertToDouble());
+ APFloat PosLargest = APFloat::getLargest(APFloat::FloatTF32(), false);
+ EXPECT_EQ(3.401162134214653489792616e+38, PosLargest.convertToDouble());
+ APFloat NegLargest = APFloat::getLargest(APFloat::FloatTF32(), true);
+ EXPECT_EQ(-3.401162134214653489792616e+38, NegLargest.convertToDouble());
+ APFloat PosSmallest =
+ APFloat::getSmallestNormalized(APFloat::FloatTF32(), false);
+ EXPECT_EQ(1.1754943508222875079687e-38, PosSmallest.convertToDouble());
+ APFloat NegSmallest =
+ APFloat::getSmallestNormalized(APFloat::FloatTF32(), true);
+ EXPECT_EQ(-1.1754943508222875079687e-38, NegSmallest.convertToDouble());
+
+ APFloat SmallestDenorm = APFloat::getSmallest(APFloat::FloatTF32(), false);
+ EXPECT_EQ(1.1479437019748901445007e-41, SmallestDenorm.convertToDouble());
+ APFloat LargestDenorm(APFloat::FloatTF32(), "0x1.FF8p-127");
+ EXPECT_EQ(/*0x1.FF8p-127*/ 1.1743464071203126178242e-38,
+ LargestDenorm.convertToDouble());
+
+ APFloat PosInf = APFloat::getInf(APFloat::FloatTF32());
+ EXPECT_EQ(std::numeric_limits<double>::infinity(), PosInf.convertToDouble());
+ APFloat NegInf = APFloat::getInf(APFloat::FloatTF32(), true);
+ EXPECT_EQ(-std::numeric_limits<double>::infinity(), NegInf.convertToDouble());
+ APFloat QNaN = APFloat::getQNaN(APFloat::FloatTF32());
+ EXPECT_TRUE(std::isnan(QNaN.convertToDouble()));
+}
+
TEST(APFloatTest, Float8E5M2FNUZToFloat) {
APFloat PosZero = APFloat::getZero(APFloat::Float8E5M2FNUZ());
APFloat PosZeroToFloat(PosZero.convertToFloat());
@@ -6473,4 +6550,40 @@ TEST(APFloatTest, Float8E4M3FNToFloat) {
EXPECT_TRUE(std::isnan(QNaN.convertToFloat()));
}
+TEST(APFloatTest, FloatTF32ToFloat) {
+ APFloat PosZero = APFloat::getZero(APFloat::FloatTF32());
+ APFloat PosZeroToFloat(PosZero.convertToFloat());
+ EXPECT_TRUE(PosZeroToFloat.isPosZero());
+ APFloat NegZero = APFloat::getZero(APFloat::FloatTF32(), true);
+ APFloat NegZeroToFloat(NegZero.convertToFloat());
+ EXPECT_TRUE(NegZeroToFloat.isNegZero());
+
+ APFloat One(APFloat::FloatTF32(), "1.0");
+ EXPECT_EQ(1.0F, One.convertToFloat());
+ APFloat Two(APFloat::FloatTF32(), "2.0");
+ EXPECT_EQ(2.0F, Two.convertToFloat());
+
+ APFloat PosLargest = APFloat::getLargest(APFloat::FloatTF32(), false);
+ EXPECT_EQ(3.40116213421e+38F, PosLargest.convertToFloat());
+
+ APFloat NegLargest = APFloat::getLargest(APFloat::FloatTF32(), true);
+ EXPECT_EQ(-3.40116213421e+38F, NegLargest.convertToFloat());
+
+ APFloat PosSmallest =
+ APFloat::getSmallestNormalized(APFloat::FloatTF32(), false);
+ EXPECT_EQ(/*0x1.p-126*/ 1.1754943508222875e-38F,
+ PosSmallest.convertToFloat());
+ APFloat NegSmallest =
+ APFloat::getSmallestNormalized(APFloat::FloatTF32(), true);
+ EXPECT_EQ(/*-0x1.p-126*/ -1.1754943508222875e-38F,
+ NegSmallest.convertToFloat());
+
+ APFloat SmallestDenorm = APFloat::getSmallest(APFloat::FloatTF32(), false);
+ EXPECT_TRUE(SmallestDenorm.isDenormal());
+ EXPECT_EQ(0x0.004p-126, SmallestDenorm.convertToFloat());
+
+ APFloat QNaN = APFloat::getQNaN(APFloat::FloatTF32());
+ EXPECT_TRUE(std::isnan(QNaN.convertToFloat()));
+}
+
} // namespace
More information about the llvm-commits mailing list