[Mlir-commits] [mlir] 2f086f2 - [APFloat] Add E4M3B11FNUZ
David Majnemer
llvmlistbot at llvm.org
Fri Mar 24 13:07:15 PDT 2023
Author: David Majnemer
Date: 2023-03-24T20:06:40Z
New Revision: 2f086f265bf97fe6543fb199f4ef874ca3522479
URL: https://github.com/llvm/llvm-project/commit/2f086f265bf97fe6543fb199f4ef874ca3522479
DIFF: https://github.com/llvm/llvm-project/commit/2f086f265bf97fe6543fb199f4ef874ca3522479.diff
LOG: [APFloat] Add E4M3B11FNUZ
X. Sun et al. (https://dl.acm.org/doi/10.5555/3454287.3454728) published
a paper showing that an FP format with 4 bits of exponent, 3 bits of
significand and an exponent bias of 11 would work quite well for ML
applications.
Google hardware supports a variant of this format where 0x80 is used to
represent NaN, as in the Float8E4M3FNUZ format. Just like the
Float8E4M3FNUZ format, this format does not support -0; values that
would map to -0 become +0 instead.
This format is proposed for inclusion in OpenXLA's StableHLO dialect: https://github.com/openxla/stablehlo/pull/1308
As part of inclusion in that dialect, APFloat needs to know how to
handle this format.
Differential Revision: https://reviews.llvm.org/D146441
Added:
Modified:
clang/lib/AST/MicrosoftMangle.cpp
llvm/include/llvm/ADT/APFloat.h
llvm/lib/Support/APFloat.cpp
llvm/unittests/ADT/APFloatTest.cpp
mlir/include/mlir-c/BuiltinTypes.h
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/BuiltinTypes.h
mlir/include/mlir/IR/BuiltinTypes.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/Types.h
mlir/lib/AsmParser/TokenKinds.def
mlir/lib/AsmParser/TypeParser.cpp
mlir/lib/Bindings/Python/IRTypes.cpp
mlir/lib/CAPI/IR/BuiltinTypes.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/Types.cpp
mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
mlir/test/IR/attribute.mlir
mlir/test/python/ir/builtin_types.py
mlir/utils/lldb-scripts/mlirDataFormatters.py
Removed:
################################################################################
diff --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp
index b8a916a72f750..d8c837bcade02 100644
--- a/clang/lib/AST/MicrosoftMangle.cpp
+++ b/clang/lib/AST/MicrosoftMangle.cpp
@@ -845,6 +845,7 @@ void MicrosoftCXXNameMangler::mangleFloat(llvm::APFloat Number) {
case APFloat::S_Float8E4M3FN:
case APFloat::S_Float8E5M2FNUZ:
case APFloat::S_Float8E4M3FNUZ:
+ case APFloat::S_Float8E4M3B11FNUZ:
llvm_unreachable("Tried to mangle unexpected APFloat semantics");
}
diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h
index c3064651fbc56..d7fcf02decea7 100644
--- a/llvm/include/llvm/ADT/APFloat.h
+++ b/llvm/include/llvm/ADT/APFloat.h
@@ -177,6 +177,13 @@ struct APFloatBase {
// This format's exponent bias is 8, instead of the 7 (2 ** (4 - 1) - 1)
// that IEEE precedent would imply.
S_Float8E4M3FNUZ,
+ // 8-bit floating point number mostly following IEEE-754 conventions
+ // and bit layout S1E4M3 with expanded range and with no infinity or signed
+ // zero.
+ // NaN is represented as negative zero. (FN -> Finite, UZ -> unsigned zero).
+ // This format's exponent bias is 11, instead of the 7 (2 ** (4 - 1) - 1)
+ // that IEEE precedent would imply.
+ S_Float8E4M3B11FNUZ,
S_x87DoubleExtended,
S_MaxSemantics = S_x87DoubleExtended,
@@ -195,6 +202,7 @@ struct APFloatBase {
static const fltSemantics &Float8E5M2FNUZ() LLVM_READNONE;
static const fltSemantics &Float8E4M3FN() LLVM_READNONE;
static const fltSemantics &Float8E4M3FNUZ() LLVM_READNONE;
+ static const fltSemantics &Float8E4M3B11FNUZ() LLVM_READNONE;
static const fltSemantics &x87DoubleExtended() LLVM_READNONE;
/// A Pseudo fltsemantic used to construct APFloats that cannot conflict with
@@ -590,6 +598,7 @@ class IEEEFloat final : public APFloatBase {
APInt convertFloat8E5M2FNUZAPFloatToAPInt() const;
APInt convertFloat8E4M3FNAPFloatToAPInt() const;
APInt convertFloat8E4M3FNUZAPFloatToAPInt() const;
+ APInt convertFloat8E4M3B11FNUZAPFloatToAPInt() const;
void initFromAPInt(const fltSemantics *Sem, const APInt &api);
void initFromHalfAPInt(const APInt &api);
void initFromBFloatAPInt(const APInt &api);
@@ -602,6 +611,7 @@ class IEEEFloat final : public APFloatBase {
void initFromFloat8E5M2FNUZAPInt(const APInt &api);
void initFromFloat8E4M3FNAPInt(const APInt &api);
void initFromFloat8E4M3FNUZAPInt(const APInt &api);
+ void initFromFloat8E4M3B11FNUZAPInt(const APInt &api);
void assign(const IEEEFloat &);
void copySignificand(const IEEEFloat &);
diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index 05053828bfc18..97c811a18e8af 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -60,8 +60,9 @@ enum class fltNonfiniteBehavior {
IEEE754,
// This behavior is present in the Float8ExMyFN* types (Float8E4M3FN,
- // Float8E5M2FNUZ, and Float8E4M3FNUZ). There is no representation for Inf,
- // and operations that would ordinarily produce Inf produce NaN instead.
+ // Float8E5M2FNUZ, Float8E4M3FNUZ, and Float8E4M3B11FNUZ). There is no
+ // representation for Inf, and operations that would ordinarily produce Inf
+ // produce NaN instead.
// The details of the NaN representation(s) in this form are determined by the
// `fltNanEncoding` enum. We treat all NaNs as quiet, as the available
// encodings do not distinguish between signalling and quiet NaN.
@@ -138,6 +139,13 @@ struct fltSemantics {
8, -6, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes};
static const fltSemantics semFloat8E4M3FNUZ = {
7, -7, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
+ static const fltSemantics semFloat8E4M3B11FNUZ = {
+ 4,
+ -10,
+ 4,
+ 8,
+ fltNonfiniteBehavior::NanOnly,
+ fltNanEncoding::NegativeZero};
static const fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80};
static const fltSemantics semBogus = {0, 0, 0, 0};
@@ -201,6 +209,8 @@ struct fltSemantics {
return Float8E4M3FN();
case S_Float8E4M3FNUZ:
return Float8E4M3FNUZ();
+ case S_Float8E4M3B11FNUZ:
+ return Float8E4M3B11FNUZ();
case S_x87DoubleExtended:
return x87DoubleExtended();
}
@@ -229,6 +239,8 @@ struct fltSemantics {
return S_Float8E4M3FN;
else if (&Sem == &llvm::APFloat::Float8E4M3FNUZ())
return S_Float8E4M3FNUZ;
+ else if (&Sem == &llvm::APFloat::Float8E4M3B11FNUZ())
+ return S_Float8E4M3B11FNUZ;
else if (&Sem == &llvm::APFloat::x87DoubleExtended())
return S_x87DoubleExtended;
else
@@ -259,6 +271,9 @@ struct fltSemantics {
const fltSemantics &APFloatBase::Float8E4M3FNUZ() {
return semFloat8E4M3FNUZ;
}
+ const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() {
+ return semFloat8E4M3B11FNUZ;
+ }
const fltSemantics &APFloatBase::x87DoubleExtended() {
return semX87DoubleExtended;
}
@@ -3709,6 +3724,33 @@ APInt IEEEFloat::convertFloat8E4M3FNUZAPFloatToAPInt() const {
(mysignificand & 0x7)));
}
+APInt IEEEFloat::convertFloat8E4M3B11FNUZAPFloatToAPInt() const {
+ assert(semantics == (const llvm::fltSemantics *)&semFloat8E4M3B11FNUZ);
+ assert(partCount() == 1);
+
+ uint32_t myexponent, mysignificand;
+
+ if (isFiniteNonZero()) {
+ myexponent = exponent + 11; // bias
+ mysignificand = (uint32_t)*significandParts();
+ if (myexponent == 1 && !(mysignificand & 0x8))
+ myexponent = 0; // denormal
+ } else if (category == fcZero) {
+ myexponent = 0;
+ mysignificand = 0;
+ } else if (category == fcInfinity) {
+ myexponent = 0;
+ mysignificand = 0;
+ } else {
+ assert(category == fcNaN && "Unknown category!");
+ myexponent = 0;
+ mysignificand = (uint32_t)*significandParts();
+ }
+
+ return APInt(8, (((sign & 1) << 7) | ((myexponent & 0xf) << 3) |
+ (mysignificand & 0x7)));
+}
+
// This function creates an APInt that is just a bit map of the floating
// point constant as it would appear in memory. It is not a conversion,
// and treating the result as a normal integer is unlikely to be useful.
@@ -3744,6 +3786,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3FNUZ)
return convertFloat8E4M3FNUZAPFloatToAPInt();
+ if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3B11FNUZ)
+ return convertFloat8E4M3B11FNUZAPFloatToAPInt();
+
assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended &&
"unknown format!");
return convertF80LongDoubleAPFloatToAPInt();
@@ -4077,6 +4122,32 @@ void IEEEFloat::initFromFloat8E4M3FNUZAPInt(const APInt &api) {
}
}
+void IEEEFloat::initFromFloat8E4M3B11FNUZAPInt(const APInt &api) {
+ uint32_t i = (uint32_t)*api.getRawData();
+ uint32_t myexponent = (i >> 3) & 0xf;
+ uint32_t mysignificand = i & 0x7;
+
+ initialize(&semFloat8E4M3B11FNUZ);
+ assert(partCount() == 1);
+
+ sign = i >> 7;
+ if (myexponent == 0 && mysignificand == 0 && sign == 0) {
+ makeZero(sign);
+ } else if (myexponent == 0 && mysignificand == 0 && sign == 1) {
+ category = fcNaN;
+ exponent = exponentNaN();
+ *significandParts() = mysignificand;
+ } else {
+ category = fcNormal;
+ exponent = myexponent - 11; // bias
+ *significandParts() = mysignificand;
+ if (myexponent == 0) // denormal
+ exponent = -10;
+ else
+ *significandParts() |= 0x8; // integer bit
+ }
+}
+
/// Treat api as containing the bits of a floating point number.
void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
assert(api.getBitWidth() == Sem->sizeInBits);
@@ -4102,6 +4173,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
return initFromFloat8E4M3FNAPInt(api);
if (Sem == &semFloat8E4M3FNUZ)
return initFromFloat8E4M3FNUZAPInt(api);
+ if (Sem == &semFloat8E4M3B11FNUZ)
+ return initFromFloat8E4M3B11FNUZAPInt(api);
llvm_unreachable(nullptr);
}
diff --git a/llvm/unittests/ADT/APFloatTest.cpp b/llvm/unittests/ADT/APFloatTest.cpp
index 79c6a3f2c53ee..cbf59acd77ee0 100644
--- a/llvm/unittests/ADT/APFloatTest.cpp
+++ b/llvm/unittests/ADT/APFloatTest.cpp
@@ -1346,6 +1346,10 @@ TEST(APFloatTest, makeNaN) {
{ 0x80ULL, APFloat::Float8E4M3FNUZ(), false, true, 0xaaULL },
{ 0x80ULL, APFloat::Float8E4M3FNUZ(), true, false, 0xaaULL },
{ 0x80ULL, APFloat::Float8E4M3FNUZ(), true, true, 0xaaULL },
+ { 0x80ULL, APFloat::Float8E4M3B11FNUZ(), false, false, 0xaaULL },
+ { 0x80ULL, APFloat::Float8E4M3B11FNUZ(), false, true, 0xaaULL },
+ { 0x80ULL, APFloat::Float8E4M3B11FNUZ(), true, false, 0xaaULL },
+ { 0x80ULL, APFloat::Float8E4M3B11FNUZ(), true, true, 0xaaULL },
// clang-format on
};
@@ -1774,6 +1778,8 @@ TEST(APFloatTest, getLargest) {
APFloat::getLargest(APFloat::Float8E4M3FNUZ()).convertToDouble());
EXPECT_EQ(57344,
APFloat::getLargest(APFloat::Float8E5M2FNUZ()).convertToDouble());
+ EXPECT_EQ(
+ 30, APFloat::getLargest(APFloat::Float8E4M3B11FNUZ()).convertToDouble());
}
TEST(APFloatTest, getSmallest) {
@@ -1818,6 +1824,13 @@ TEST(APFloatTest, getSmallest) {
EXPECT_TRUE(test.isFiniteNonZero());
EXPECT_TRUE(test.isDenormal());
EXPECT_TRUE(test.bitwiseIsEqual(expected));
+
+ test = APFloat::getSmallest(APFloat::Float8E4M3B11FNUZ(), false);
+ expected = APFloat(APFloat::Float8E4M3B11FNUZ(), "0x0.2p-10");
+ EXPECT_FALSE(test.isNegative());
+ EXPECT_TRUE(test.isFiniteNonZero());
+ EXPECT_TRUE(test.isDenormal());
+ EXPECT_TRUE(test.bitwiseIsEqual(expected));
}
TEST(APFloatTest, getSmallestNormalized) {
@@ -1884,6 +1897,14 @@ TEST(APFloatTest, getSmallestNormalized) {
EXPECT_FALSE(test.isDenormal());
EXPECT_TRUE(test.bitwiseIsEqual(expected));
EXPECT_TRUE(test.isSmallestNormalized());
+
+ test = APFloat::getSmallestNormalized(APFloat::Float8E4M3B11FNUZ(), false);
+ expected = APFloat(APFloat::Float8E4M3B11FNUZ(), "0x1.0p-10");
+ EXPECT_FALSE(test.isNegative());
+ EXPECT_TRUE(test.isFiniteNonZero());
+ EXPECT_FALSE(test.isDenormal());
+ EXPECT_TRUE(test.bitwiseIsEqual(expected));
+ EXPECT_TRUE(test.isSmallestNormalized());
}
TEST(APFloatTest, getZero) {
@@ -1913,7 +1934,9 @@ TEST(APFloatTest, getZero) {
{&APFloat::Float8E4M3FN(), false, true, {0, 0}, 1},
{&APFloat::Float8E4M3FN(), true, true, {0x80ULL, 0}, 1},
{&APFloat::Float8E4M3FNUZ(), false, false, {0, 0}, 1},
- {&APFloat::Float8E4M3FNUZ(), true, false, {0, 0}, 1}};
+ {&APFloat::Float8E4M3FNUZ(), true, false, {0, 0}, 1},
+ {&APFloat::Float8E4M3B11FNUZ(), false, false, {0, 0}, 1},
+ {&APFloat::Float8E4M3B11FNUZ(), true, false, {0, 0}, 1}};
const unsigned NumGetZeroTests = std::size(GetZeroTest);
for (unsigned i = 0; i < NumGetZeroTests; ++i) {
APFloat test = APFloat::getZero(*GetZeroTest[i].semantics,
@@ -1944,14 +1967,14 @@ TEST(APFloatTest, copySign) {
EXPECT_TRUE(APFloat(42.0).bitwiseIsEqual(
APFloat::copySign(APFloat(42.0), APFloat(1.0))));
// For floating-point formats with unsigned 0, copySign() to a zero is a noop
- EXPECT_TRUE(
- APFloat::getZero(APFloat::Float8E4M3FNUZ())
- .bitwiseIsEqual(APFloat::copySign(
- APFloat::getZero(APFloat::Float8E4M3FNUZ()), APFloat(-1.0))));
- EXPECT_TRUE(
- APFloat::getNaN(APFloat::Float8E4M3FNUZ(), true)
- .bitwiseIsEqual(APFloat::copySign(
- APFloat::getNaN(APFloat::Float8E4M3FNUZ(), true), APFloat(1.0))));
+ for (APFloat::Semantics S :
+ {APFloat::S_Float8E4M3FNUZ, APFloat::S_Float8E4M3B11FNUZ}) {
+ const llvm::fltSemantics &Sem = APFloat::EnumToSemantics(S);
+ EXPECT_TRUE(APFloat::getZero(Sem).bitwiseIsEqual(
+ APFloat::copySign(APFloat::getZero(Sem), APFloat(-1.0))));
+ EXPECT_TRUE(APFloat::getNaN(Sem, true).bitwiseIsEqual(
+ APFloat::copySign(APFloat::getNaN(Sem, true), APFloat(1.0))));
+ }
}
TEST(APFloatTest, convert) {
@@ -2073,17 +2096,18 @@ TEST(APFloatTest, Float8UZConvert) {
{APFloat::getSNaN(APFloat::IEEEsingle(), true), APFloat::opInvalidOp},
{APFloat::getInf(APFloat::IEEEsingle(), false), APFloat::opInexact},
{APFloat::getInf(APFloat::IEEEsingle(), true), APFloat::opInexact}};
- for (auto [toTest, expectedRes] : toNaNTests) {
- llvm::SmallString<16> value;
- toTest.toString(value);
- SCOPED_TRACE("toTest = " + value);
- for (const fltSemantics *sem :
- {&APFloat::Float8E4M3FNUZ(), &APFloat::Float8E5M2FNUZ()}) {
- SCOPED_TRACE("Semantics = " +
- std::to_string(APFloat::SemanticsToEnum(*sem)));
+ for (APFloat::Semantics S :
+ {APFloat::S_Float8E5M2FNUZ, APFloat::S_Float8E4M3FNUZ,
+ APFloat::S_Float8E4M3B11FNUZ}) {
+ const llvm::fltSemantics &Sem = APFloat::EnumToSemantics(S);
+ SCOPED_TRACE("Semantics = " + std::to_string(S));
+ for (auto [toTest, expectedRes] : toNaNTests) {
+ llvm::SmallString<16> value;
+ toTest.toString(value);
+ SCOPED_TRACE("toTest = " + value);
losesInfo = false;
APFloat test = toTest;
- EXPECT_EQ(test.convert(*sem, APFloat::rmNearestTiesToAway, &losesInfo),
+ EXPECT_EQ(test.convert(Sem, APFloat::rmNearestTiesToAway, &losesInfo),
expectedRes);
EXPECT_TRUE(test.isNaN());
EXPECT_TRUE(test.isNegative());
@@ -2092,37 +2116,34 @@ TEST(APFloatTest, Float8UZConvert) {
EXPECT_EQ(0x80, test.bitcastToAPInt());
EXPECT_TRUE(losesInfo);
}
- }
- // Negative zero conversions are information losing.
- losesInfo = false;
- APFloat test = APFloat::getZero(APFloat::IEEEsingle(), true);
- EXPECT_EQ(test.convert(APFloat::Float8E5M2FNUZ(),
- APFloat::rmNearestTiesToAway, &losesInfo),
- APFloat::opInexact);
- EXPECT_TRUE(test.isZero());
- EXPECT_FALSE(test.isNegative());
- EXPECT_TRUE(losesInfo);
- EXPECT_EQ(0x0, test.bitcastToAPInt());
-
- losesInfo = true;
- test = APFloat::getZero(APFloat::IEEEsingle(), false);
- EXPECT_EQ(test.convert(APFloat::Float8E5M2FNUZ(),
- APFloat::rmNearestTiesToAway, &losesInfo),
- APFloat::opOK);
- EXPECT_TRUE(test.isZero());
- EXPECT_FALSE(test.isNegative());
- EXPECT_FALSE(losesInfo);
- EXPECT_EQ(0x0, test.bitcastToAPInt());
+ // Negative zero conversions are information losing.
+ losesInfo = false;
+ APFloat test = APFloat::getZero(APFloat::IEEEsingle(), true);
+ EXPECT_EQ(test.convert(Sem, APFloat::rmNearestTiesToAway, &losesInfo),
+ APFloat::opInexact);
+ EXPECT_TRUE(test.isZero());
+ EXPECT_FALSE(test.isNegative());
+ EXPECT_TRUE(losesInfo);
+ EXPECT_EQ(0x0, test.bitcastToAPInt());
+
+ losesInfo = true;
+ test = APFloat::getZero(APFloat::IEEEsingle(), false);
+ EXPECT_EQ(test.convert(Sem, APFloat::rmNearestTiesToAway, &losesInfo),
+ APFloat::opOK);
+ EXPECT_TRUE(test.isZero());
+ EXPECT_FALSE(test.isNegative());
+ EXPECT_FALSE(losesInfo);
+ EXPECT_EQ(0x0, test.bitcastToAPInt());
- // Except in casts between ourselves.
- losesInfo = true;
- test = APFloat::getZero(APFloat::Float8E5M2FNUZ());
- EXPECT_EQ(test.convert(APFloat::Float8E4M3FNUZ(),
- APFloat::rmNearestTiesToAway, &losesInfo),
- APFloat::opOK);
- EXPECT_FALSE(losesInfo);
- EXPECT_EQ(0x0, test.bitcastToAPInt());
+ // Except in casts between ourselves.
+ losesInfo = true;
+ test = APFloat::getZero(Sem);
+ EXPECT_EQ(test.convert(Sem, APFloat::rmNearestTiesToAway, &losesInfo),
+ APFloat::opOK);
+ EXPECT_FALSE(losesInfo);
+ EXPECT_EQ(0x0, test.bitcastToAPInt());
+ }
}
TEST(APFloatTest, PPCDoubleDouble) {
@@ -5003,7 +5024,7 @@ TEST(APFloatTest, Float8ExhaustivePair) {
// Test each pair of 8-bit floats with non-standard semantics
for (APFloat::Semantics Sem :
{APFloat::S_Float8E4M3FN, APFloat::S_Float8E5M2FNUZ,
- APFloat::S_Float8E4M3FNUZ}) {
+ APFloat::S_Float8E4M3FNUZ, APFloat::S_Float8E4M3B11FNUZ}) {
const llvm::fltSemantics &S = APFloat::EnumToSemantics(Sem);
for (int i = 0; i < 256; i++) {
for (int j = 0; j < 256; j++) {
@@ -5483,50 +5504,54 @@ TEST(APFloatTest, UnsignedZeroArithmeticSpecial) {
// cases and so are not repeated here.
// The IEEE round towards negative rule doesn't apply
- APFloat test = APFloat::getSmallest(APFloat::Float8E4M3FNUZ());
- APFloat rhs = test;
- EXPECT_EQ(test.subtract(rhs, APFloat::rmTowardNegative), APFloat::opOK);
- EXPECT_TRUE(test.isZero());
- EXPECT_FALSE(test.isNegative());
-
- // Multiplication of (small) * (-small) is +0
- test = APFloat::getSmallestNormalized(APFloat::Float8E4M3FNUZ());
- rhs = -test;
- EXPECT_EQ(test.multiply(rhs, APFloat::rmNearestTiesToAway),
- APFloat::opInexact | APFloat::opUnderflow);
- EXPECT_TRUE(test.isZero());
- EXPECT_FALSE(test.isNegative());
+ for (APFloat::Semantics S :
+ {APFloat::S_Float8E4M3FNUZ, APFloat::S_Float8E4M3B11FNUZ}) {
+ const llvm::fltSemantics &Sem = APFloat::EnumToSemantics(S);
+ APFloat test = APFloat::getSmallest(Sem);
+ APFloat rhs = test;
+ EXPECT_EQ(test.subtract(rhs, APFloat::rmTowardNegative), APFloat::opOK);
+ EXPECT_TRUE(test.isZero());
+ EXPECT_FALSE(test.isNegative());
- // Dividing the negatize float_min by anything gives +0
- test = APFloat::getSmallest(APFloat::Float8E4M3FNUZ(), true);
- rhs = APFloat(APFloat::Float8E4M3FNUZ(), "2.0");
- EXPECT_EQ(test.divide(rhs, APFloat::rmNearestTiesToEven),
- APFloat::opInexact | APFloat::opUnderflow);
- EXPECT_TRUE(test.isZero());
- EXPECT_FALSE(test.isNegative());
+ // Multiplication of (small) * (-small) is +0
+ test = APFloat::getSmallestNormalized(Sem);
+ rhs = -test;
+ EXPECT_EQ(test.multiply(rhs, APFloat::rmNearestTiesToAway),
+ APFloat::opInexact | APFloat::opUnderflow);
+ EXPECT_TRUE(test.isZero());
+ EXPECT_FALSE(test.isNegative());
- // Remainder can't copy sign because there's only one zero
- test = APFloat(APFloat::Float8E4M3FNUZ(), "-4.0");
- rhs = APFloat(APFloat::Float8E4M3FNUZ(), "2.0");
- EXPECT_EQ(test.remainder(rhs), APFloat::opOK);
- EXPECT_TRUE(test.isZero());
- EXPECT_FALSE(test.isNegative());
+ // Dividing the negative float_min by anything gives +0
+ test = APFloat::getSmallest(Sem, true);
+ rhs = APFloat(Sem, "2.0");
+ EXPECT_EQ(test.divide(rhs, APFloat::rmNearestTiesToEven),
+ APFloat::opInexact | APFloat::opUnderflow);
+ EXPECT_TRUE(test.isZero());
+ EXPECT_FALSE(test.isNegative());
- // And same for mod
- test = APFloat(APFloat::Float8E4M3FNUZ(), "-4.0");
- rhs = APFloat(APFloat::Float8E4M3FNUZ(), "2.0");
- EXPECT_EQ(test.mod(rhs), APFloat::opOK);
- EXPECT_TRUE(test.isZero());
- EXPECT_FALSE(test.isNegative());
+ // Remainder can't copy sign because there's only one zero
+ test = APFloat(Sem, "-4.0");
+ rhs = APFloat(Sem, "2.0");
+ EXPECT_EQ(test.remainder(rhs), APFloat::opOK);
+ EXPECT_TRUE(test.isZero());
+ EXPECT_FALSE(test.isNegative());
- // FMA correctly handles both the multiply and add parts of all this
- test = APFloat(APFloat::Float8E4M3FNUZ(), "2.0");
- rhs = test;
- APFloat addend = APFloat(APFloat::Float8E4M3FNUZ(), "-4.0");
- EXPECT_EQ(test.fusedMultiplyAdd(rhs, addend, APFloat::rmTowardNegative),
- APFloat::opOK);
- EXPECT_TRUE(test.isZero());
- EXPECT_FALSE(test.isNegative());
+ // And same for mod
+ test = APFloat(Sem, "-4.0");
+ rhs = APFloat(Sem, "2.0");
+ EXPECT_EQ(test.mod(rhs), APFloat::opOK);
+ EXPECT_TRUE(test.isZero());
+ EXPECT_FALSE(test.isNegative());
+
+ // FMA correctly handles both the multiply and add parts of all this
+ test = APFloat(Sem, "2.0");
+ rhs = test;
+ APFloat addend = APFloat(Sem, "-4.0");
+ EXPECT_EQ(test.fusedMultiplyAdd(rhs, addend, APFloat::rmTowardNegative),
+ APFloat::opOK);
+ EXPECT_TRUE(test.isZero());
+ EXPECT_FALSE(test.isNegative());
+ }
}
TEST(APFloatTest, Float8E5M2FNUZAdd) {
@@ -5590,7 +5615,8 @@ TEST(APFloatTest, Float8UnsignedZeroExhaustive) {
const double largest;
const double smallest;
} const exhaustiveTests[] = {{&APFloat::Float8E5M2FNUZ(), 57344., 0x1.0p-17},
- {&APFloat::Float8E4M3FNUZ(), 240., 0x1.0p-10}};
+ {&APFloat::Float8E4M3FNUZ(), 240., 0x1.0p-10},
+ {&APFloat::Float8E4M3B11FNUZ(), 30., 0x1.0p-13}};
for (const auto &testInfo : exhaustiveTests) {
const fltSemantics &sem = *testInfo.semantics;
SCOPED_TRACE("Semantics=" + std::to_string(APFloat::SemanticsToEnum(sem)));
@@ -5634,71 +5660,79 @@ TEST(APFloatTest, Float8UnsignedZeroExhaustive) {
}
TEST(APFloatTest, Float8E4M3FNUZNext) {
- APFloat test(APFloat::Float8E4M3FNUZ(), APFloat::uninitialized);
- APFloat expected(APFloat::Float8E4M3FNUZ(), APFloat::uninitialized);
-
- // 1. NextUp of largest bit pattern is nan
- test = APFloat::getLargest(APFloat::Float8E4M3FNUZ());
- expected = APFloat::getNaN(APFloat::Float8E4M3FNUZ());
- EXPECT_EQ(test.next(false), APFloat::opOK);
- EXPECT_FALSE(test.isInfinity());
- EXPECT_FALSE(test.isZero());
- EXPECT_TRUE(test.isNaN());
- EXPECT_TRUE(test.bitwiseIsEqual(expected));
+ for (APFloat::Semantics S :
+ {APFloat::S_Float8E4M3FNUZ, APFloat::S_Float8E4M3B11FNUZ}) {
+ const llvm::fltSemantics &Sem = APFloat::EnumToSemantics(S);
+ APFloat test(Sem, APFloat::uninitialized);
+ APFloat expected(Sem, APFloat::uninitialized);
+
+ // 1. NextUp of largest bit pattern is nan
+ test = APFloat::getLargest(Sem);
+ expected = APFloat::getNaN(Sem);
+ EXPECT_EQ(test.next(false), APFloat::opOK);
+ EXPECT_FALSE(test.isInfinity());
+ EXPECT_FALSE(test.isZero());
+ EXPECT_TRUE(test.isNaN());
+ EXPECT_TRUE(test.bitwiseIsEqual(expected));
- // 2. NextUp of smallest negative denormal is +0
- test = APFloat::getSmallest(APFloat::Float8E4M3FNUZ(), true);
- expected = APFloat::getZero(APFloat::Float8E4M3FNUZ(), false);
- EXPECT_EQ(test.next(false), APFloat::opOK);
- EXPECT_FALSE(test.isNegZero());
- EXPECT_TRUE(test.isPosZero());
- EXPECT_TRUE(test.bitwiseIsEqual(expected));
+ // 2. NextUp of smallest negative denormal is +0
+ test = APFloat::getSmallest(Sem, true);
+ expected = APFloat::getZero(Sem, false);
+ EXPECT_EQ(test.next(false), APFloat::opOK);
+ EXPECT_FALSE(test.isNegZero());
+ EXPECT_TRUE(test.isPosZero());
+ EXPECT_TRUE(test.bitwiseIsEqual(expected));
- // 3. nextDown of negative of largest value is NaN
- test = APFloat::getLargest(APFloat::Float8E4M3FNUZ(), true);
- expected = APFloat::getNaN(APFloat::Float8E4M3FNUZ());
- EXPECT_EQ(test.next(true), APFloat::opOK);
- EXPECT_FALSE(test.isInfinity());
- EXPECT_FALSE(test.isZero());
- EXPECT_TRUE(test.isNaN());
- EXPECT_TRUE(test.bitwiseIsEqual(expected));
+ // 3. nextDown of negative of largest value is NaN
+ test = APFloat::getLargest(Sem, true);
+ expected = APFloat::getNaN(Sem);
+ EXPECT_EQ(test.next(true), APFloat::opOK);
+ EXPECT_FALSE(test.isInfinity());
+ EXPECT_FALSE(test.isZero());
+ EXPECT_TRUE(test.isNaN());
+ EXPECT_TRUE(test.bitwiseIsEqual(expected));
- // 4. nextDown of +0 is smallest negative denormal
- test = APFloat::getZero(APFloat::Float8E4M3FNUZ(), false);
- expected = APFloat::getSmallest(APFloat::Float8E4M3FNUZ(), true);
- EXPECT_EQ(test.next(true), APFloat::opOK);
- EXPECT_FALSE(test.isZero());
- EXPECT_TRUE(test.isDenormal());
- EXPECT_TRUE(test.bitwiseIsEqual(expected));
+ // 4. nextDown of +0 is smallest negative denormal
+ test = APFloat::getZero(Sem, false);
+ expected = APFloat::getSmallest(Sem, true);
+ EXPECT_EQ(test.next(true), APFloat::opOK);
+ EXPECT_FALSE(test.isZero());
+ EXPECT_TRUE(test.isDenormal());
+ EXPECT_TRUE(test.bitwiseIsEqual(expected));
- // 5. nextUp of NaN is NaN
- test = APFloat::getNaN(APFloat::Float8E4M3FNUZ(), false);
- expected = APFloat::getNaN(APFloat::Float8E4M3FNUZ(), true);
- EXPECT_EQ(test.next(false), APFloat::opOK);
- EXPECT_TRUE(test.isNaN());
+ // 5. nextUp of NaN is NaN
+ test = APFloat::getNaN(Sem, false);
+ expected = APFloat::getNaN(Sem, true);
+ EXPECT_EQ(test.next(false), APFloat::opOK);
+ EXPECT_TRUE(test.isNaN());
- // 6. nextDown of NaN is NaN
- test = APFloat::getNaN(APFloat::Float8E4M3FNUZ(), false);
- expected = APFloat::getNaN(APFloat::Float8E4M3FNUZ(), true);
- EXPECT_EQ(test.next(true), APFloat::opOK);
- EXPECT_TRUE(test.isNaN());
+ // 6. nextDown of NaN is NaN
+ test = APFloat::getNaN(Sem, false);
+ expected = APFloat::getNaN(Sem, true);
+ EXPECT_EQ(test.next(true), APFloat::opOK);
+ EXPECT_TRUE(test.isNaN());
+ }
}
TEST(APFloatTest, Float8E4M3FNUZChangeSign) {
- APFloat test = APFloat(APFloat::Float8E4M3FNUZ(), "1.0");
- APFloat expected = APFloat(APFloat::Float8E4M3FNUZ(), "-1.0");
- test.changeSign();
- EXPECT_TRUE(test.bitwiseIsEqual(expected));
+ for (APFloat::Semantics S :
+ {APFloat::S_Float8E4M3FNUZ, APFloat::S_Float8E4M3B11FNUZ}) {
+ const llvm::fltSemantics &Sem = APFloat::EnumToSemantics(S);
+ APFloat test = APFloat(Sem, "1.0");
+ APFloat expected = APFloat(Sem, "-1.0");
+ test.changeSign();
+ EXPECT_TRUE(test.bitwiseIsEqual(expected));
- test = APFloat::getZero(APFloat::Float8E4M3FNUZ());
- expected = test;
- test.changeSign();
- EXPECT_TRUE(test.bitwiseIsEqual(expected));
+ test = APFloat::getZero(Sem);
+ expected = test;
+ test.changeSign();
+ EXPECT_TRUE(test.bitwiseIsEqual(expected));
- test = APFloat::getNaN(APFloat::Float8E4M3FNUZ());
- expected = test;
- test.changeSign();
- EXPECT_TRUE(test.bitwiseIsEqual(expected));
+ test = APFloat::getNaN(Sem);
+ expected = test;
+ test.changeSign();
+ EXPECT_TRUE(test.bitwiseIsEqual(expected));
+ }
}
TEST(APFloatTest, Float8E4M3FNUZFromString) {
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 8b855d8c39a4d..2b7606f3d9caf 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -95,6 +95,13 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx);
+/// Checks whether the given type is an f8E4M3B11FNUZ type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type);
+
+/// Creates an f8E4M3B11FNUZ type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx);
+
/// Checks whether the given type is a bf16 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index f970d89dd410f..7197b1364bbce 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -64,6 +64,7 @@ class Builder {
FloatType getFloat8E4M3FNType();
FloatType getFloat8E5M2FNUZType();
FloatType getFloat8E4M3FNUZType();
+ FloatType getFloat8E4M3B11FNUZType();
FloatType getBF16Type();
FloatType getF16Type();
FloatType getF32Type();
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 33995f34ee39b..baee29c554c3d 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -49,6 +49,7 @@ class FloatType : public Type {
static FloatType getFloat8E4M3FN(MLIRContext *ctx);
static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
+ static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type);
@@ -376,9 +377,10 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
}
inline bool FloatType::classof(Type type) {
- return type.isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
- Float8E4M3FNUZType, BFloat16Type, Float16Type, Float32Type,
- Float64Type, Float80Type, Float128Type>();
+ return type
+ .isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+ Float8E4M3FNUZType, Float8E4M3B11FNUZType, BFloat16Type, Float16Type,
+ Float32Type, Float64Type, Float80Type, Float128Type>();
}
inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
@@ -397,6 +399,10 @@ inline FloatType FloatType::getFloat8E4M3FNUZ(MLIRContext *ctx) {
return Float8E4M3FNUZType::get(ctx);
}
+inline FloatType FloatType::getFloat8E4M3B11FNUZ(MLIRContext *ctx) {
+ return Float8E4M3B11FNUZType::get(ctx);
+}
+
inline FloatType FloatType::getBF16(MLIRContext *ctx) {
return BFloat16Type::get(ctx);
}
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4d6120ff9df88..3e41625c0b688 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -162,6 +162,28 @@ def Builtin_Float8E4M3FNUZ : Builtin_FloatType<"Float8E4M3FNUZ"> {
}];
}
+//===----------------------------------------------------------------------===//
+// Float8E4M3B11FNUZType
+
+def Builtin_Float8E4M3B11FNUZ : Builtin_FloatType<"Float8E4M3B11FNUZ"> {
+ let summary = "8-bit floating point with 3 bit mantissa";
+ let description = [{
+ An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
+ mantissa. This is not a standard type as defined by IEEE-754, but it follows
+ similar conventions, with the exception that there are no infinity values,
+ no negative zero, and only one NaN representation. This type has the
+ following characteristics:
+
+ * bit encoding: S1E4M3
+ * exponent bias: 11
+ * infinities: Not supported
+ * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s
+ * denormals when exponent is 0
+
+ Related to: https://dl.acm.org/doi/10.5555/3454287.3454728
+ }];
+}
+
//===----------------------------------------------------------------------===//
// BFloat16Type
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 527ccc05e090d..554f02675363a 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -490,6 +490,8 @@ def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
BuildableType<"$_builder.getFloat8E5M2Type()">;
def F8E4M3FNUZ : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">,
BuildableType<"$_builder.getFloat8E4M3FNUZType()">;
+def F8E4M3B11FNUZ : Type<CPred<"$_self.isFloat8E4M3B11FNUZ()">, "f8E4M3B11FNUZ type">,
+ BuildableType<"$_builder.getFloat8E4M3B11FNUZType()">;
def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index fc377bee3cefc..adefc2908b552 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -124,6 +124,7 @@ class Type {
bool isFloat8E4M3FN() const;
bool isFloat8E5M2FNUZ() const;
bool isFloat8E4M3FNUZ() const;
+ bool isFloat8E4M3B11FNUZ() const;
bool isBF16() const;
bool isF16() const;
bool isF32() const;
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 0e666c792b9de..9a632e3570fb5 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -97,6 +97,7 @@ TOK_KEYWORD(f8E5M2)
TOK_KEYWORD(f8E4M3FN)
TOK_KEYWORD(f8E5M2FNUZ)
TOK_KEYWORD(f8E4M3FNUZ)
+TOK_KEYWORD(f8E4M3B11FNUZ)
TOK_KEYWORD(f128)
TOK_KEYWORD(false)
TOK_KEYWORD(floordiv)
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 47078c1ba0472..737767ce9101b 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -35,6 +35,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
case Token::kw_f8E4M3FN:
case Token::kw_f8E5M2FNUZ:
case Token::kw_f8E4M3FNUZ:
+ case Token::kw_f8E4M3B11FNUZ:
case Token::kw_bf16:
case Token::kw_f16:
case Token::kw_f32:
@@ -303,6 +304,9 @@ Type Parser::parseNonFunctionType() {
case Token::kw_f8E4M3FNUZ:
consumeToken(Token::kw_f8E4M3FNUZ);
return builder.getFloat8E4M3FNUZType();
+ case Token::kw_f8E4M3B11FNUZ:
+ consumeToken(Token::kw_f8E4M3B11FNUZ);
+ return builder.getFloat8E4M3B11FNUZType();
case Token::kw_bf16:
consumeToken(Token::kw_bf16);
return builder.getBF16Type();
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 2166bab902a13..6d381b1c01701 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -157,6 +157,24 @@ class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
}
};
+/// Floating Point Type subclass - Float8E4M3B11FNUZ.
+class PyFloat8E4M3B11FNUZType : public PyConcreteType<PyFloat8E4M3B11FNUZType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
+ static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
+ return PyFloat8E4M3B11FNUZType(context->getRef(), t);
+ },
+ py::arg("context") = py::none(), "Create a float8_e4m3b11fnuz type.");
+ }
+};
+
/// Floating Point Type subclass - Float8E5M2FNUZ.
class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
public:
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index aea1221200af6..2468c05463f4a 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -100,6 +100,14 @@ MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E4M3FNUZ(unwrap(ctx)));
}
+bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) {
+ return unwrap(type).isFloat8E4M3B11FNUZ();
+}
+
+MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) {
+ return wrap(FloatType::getFloat8E4M3B11FNUZ(unwrap(ctx)));
+}
+
bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
MlirType mlirBF16TypeGet(MlirContext ctx) {
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 75448955f3123..1820803f24258 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2428,6 +2428,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
.Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
.Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; })
.Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
+ .Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })
.Case<BFloat16Type>([&](Type) { os << "bf16"; })
.Case<Float16Type>([&](Type) { os << "f16"; })
.Case<Float32Type>([&](Type) { os << "f32"; })
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 8eab32b201a04..9203943123470 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -49,6 +49,10 @@ FloatType Builder::getFloat8E4M3FNUZType() {
return FloatType::getFloat8E4M3FNUZ(context);
}
+FloatType Builder::getFloat8E4M3B11FNUZType() {
+ return FloatType::getFloat8E4M3B11FNUZ(context);
+}
+
FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
FloatType Builder::getF16Type() { return FloatType::getF16(context); }
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 6e6c6b9683c78..7e95ca137fdd1 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -89,7 +89,7 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
unsigned FloatType::getWidth() {
if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
- Float8E4M3FNUZType>())
+ Float8E4M3FNUZType, Float8E4M3B11FNUZType>())
return 8;
if (isa<Float16Type, BFloat16Type>())
return 16;
@@ -114,6 +114,8 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() {
return APFloat::Float8E5M2FNUZ();
if (isa<Float8E4M3FNUZType>())
return APFloat::Float8E4M3FNUZ();
+ if (isa<Float8E4M3B11FNUZType>())
+ return APFloat::Float8E4M3B11FNUZ();
if (isa<BFloat16Type>())
return APFloat::BFloat();
if (isa<Float16Type>())
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 29d49edf9efb9..daa4a6af63020 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -214,6 +214,7 @@ class MLIRContextImpl {
Float8E4M3FNType f8E4M3FNTy;
Float8E5M2FNUZType f8E5M2FNUZTy;
Float8E4M3FNUZType f8E4M3FNUZTy;
+ Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
BFloat16Type bf16Ty;
Float16Type f16Ty;
Float32Type f32Ty;
@@ -288,6 +289,7 @@ MLIRContext::MLIRContext(const DialectRegistry ®istry, Threading setting)
impl->f8E4M3FNTy = TypeUniquer::get<Float8E4M3FNType>(this);
impl->f8E5M2FNUZTy = TypeUniquer::get<Float8E5M2FNUZType>(this);
impl->f8E4M3FNUZTy = TypeUniquer::get<Float8E4M3FNUZType>(this);
+ impl->f8E4M3B11FNUZTy = TypeUniquer::get<Float8E4M3B11FNUZType>(this);
impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
impl->f16Ty = TypeUniquer::get<Float16Type>(this);
impl->f32Ty = TypeUniquer::get<Float32Type>(this);
@@ -892,6 +894,9 @@ Float8E5M2FNUZType Float8E5M2FNUZType::get(MLIRContext *context) {
Float8E4M3FNUZType Float8E4M3FNUZType::get(MLIRContext *context) {
return context->getImpl().f8E4M3FNUZTy;
}
+Float8E4M3B11FNUZType Float8E4M3B11FNUZType::get(MLIRContext *context) {
+ return context->getImpl().f8E4M3B11FNUZTy;
+}
BFloat16Type BFloat16Type::get(MLIRContext *context) {
return context->getImpl().bf16Ty;
}
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index e739786bd3990..d3d1d860d5d32 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -38,6 +38,7 @@ bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); }
bool Type::isFloat8E4M3FN() const { return isa<Float8E4M3FNType>(); }
bool Type::isFloat8E5M2FNUZ() const { return isa<Float8E5M2FNUZType>(); }
bool Type::isFloat8E4M3FNUZ() const { return isa<Float8E4M3FNUZType>(); }
+bool Type::isFloat8E4M3B11FNUZ() const { return isa<Float8E4M3B11FNUZType>(); }
bool Type::isBF16() const { return isa<BFloat16Type>(); }
bool Type::isF16() const { return isa<Float16Type>(); }
bool Type::isF32() const { return isa<Float32Type>(); }
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 7d5ff23f60ab7..75b25bd8c1c9d 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -53,6 +53,7 @@ __all__ = [
"Float8E4M3FNType",
"Float8E5M2Type",
"Float8E4M3FNUZType",
+ "Float8E4M3B11FNUZType",
"Float8E5M2FNUZType",
"F16Type",
"F32Type",
@@ -602,6 +603,13 @@ class Float8E4M3FNUZType(Type):
@staticmethod
def isinstance(arg: Any) -> bool: ...
+class Float8E4M3B11FNUZType(Type):
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @staticmethod
+ def get(*args, **kwargs) -> Float8E4M3B11FNUZType: ...
+ @staticmethod
+ def isinstance(arg: Any) -> bool: ...
+
class Float8E5M2FNUZType(Type):
def __init__(self, cast_from_type: Type) -> None: ...
@staticmethod
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index de840f950f458..c296507868cbc 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -52,6 +52,10 @@ func.func @float_attrs_pass() {
// CHECK: float_attr = 2.000000e+00 : f8E4M3FNUZ
float_attr = 2. : f8E4M3FNUZ
} : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 2.000000e+00 : f8E4M3B11FNUZ
+ float_attr = 2. : f8E4M3B11FNUZ
+ } : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 2.000000e+00 : f16
float_attr = 2. : f16
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 594cc6620e396..e383a78f40b8a 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -202,6 +202,8 @@ def testFloatType():
print("float:", Float8E5M2FNUZType.get())
# CHECK: float: f8E4M3FNUZ
print("float:", Float8E4M3FNUZType.get())
+ # CHECK: float: f8E4M3B11FNUZ
+ print("float:", Float8E4M3B11FNUZType.get())
# CHECK: float: bf16
print("float:", BF16Type.get())
# CHECK: float: f16
diff --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py
index 908a734f6e30d..bfd76a7d0ca28 100644
--- a/mlir/utils/lldb-scripts/mlirDataFormatters.py
+++ b/mlir/utils/lldb-scripts/mlirDataFormatters.py
@@ -54,6 +54,7 @@ def build_ptr_str_from_addr(addrValue: lldb.SBValue, type: lldb.SBType):
"mlir::Float8E4M3FNType": '"f8E4M3FN"',
"mlir::Float8E5M2FNUZType": '"f8E5M2FNUZ"',
"mlir::Float8E4M3FNUZType": '"f8E4M3FNUZ"',
+ "mlir::Float8E4M3B11FNUZType": '"f8E4M3B11FNUZ"',
"mlir::BFloat16Type": '"bf16"',
"mlir::Float16Type": '"f16"',
"mlir::Float32Type": '"f32"',
More information about the Mlir-commits
mailing list