[Mlir-commits] [mlir] 2dc68b5 - Add APFloat and MLIR type support for fp8 (e5m2).
    Stella Laurenzo 
    llvmlistbot at llvm.org
       
    Sun Oct  2 17:25:01 PDT 2022
    
    
  
Author: Stella Laurenzo
Date: 2022-10-02T17:17:08-07:00
New Revision: 2dc68b5398258c7a0cf91f10192d058e787afcdf
URL: https://github.com/llvm/llvm-project/commit/2dc68b5398258c7a0cf91f10192d058e787afcdf
DIFF: https://github.com/llvm/llvm-project/commit/2dc68b5398258c7a0cf91f10192d058e787afcdf.diff
LOG: Add APFloat and MLIR type support for fp8 (e5m2).
This is a first step towards a high-level representation for the fp8 types
that are being built into hardware on near-term roadmaps. Like the
BFLOAT16 type, the fp8 family is inspired by IEEE-754 binary
floating-point formats but, because of the tight size limit, has been
tweaked in several ways to make maximal use of the available range and
precision in different scenarios. The list of variants is small, finite,
and bounded by real hardware.
This patch introduces the E5M2 FP8 format as proposed by Nvidia, ARM,
and Intel in the paper: https://arxiv.org/pdf/2209.05433.pdf
As the more IEEE-754-conformant of the two datatypes described in the
paper, E5M2 is being plumbed through LLVM's APFloat type and MLIR's type
system first, as a template. It will be followed by the range-optimized
E4M3 FP8 format described in the paper. Since that format deviates
further from IEEE-754 norms, it may require more debate and more
implementation complexity.
Given that these two cases represent two parts of the FP8 implementation
space, we recommend the following naming:
* `F8M<N>`: For FP8 types that can be conceived of as following the
  same rules as FP16 but with fewer mantissa/exponent bits. Because the
  width is fixed at 8 bits with one sign bit, including the number of
  mantissa bits in the type name is enough to fully specify the type.
  This naming scheme is intended for the E5M2 type described in the paper
  (this patch spells it `Float8E5M2` in C++ and `f8E5M2` in MLIR assembly,
  which also names the exponent width).
* `F8M<N>F`: For FP8 types, such as E4M3, that only support finite
  values.
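Since an FP8 value always has 8 bits, one of which is the sign, the mantissa width M fixes the exponent width E = 7 - M. The minimal sketch below derives the IEEE-style parameters from M alone, assuming IEEE-754-style rules (bias 2^(E-1) - 1, all-ones exponent reserved for Inf/NaN); it therefore does not cover finite-only variants such as E4M3. `Fp8Params` and `deriveFp8Params` are hypothetical names; the four fields mirror the numeric fields of `llvm::fltSemantics` (maxExponent, minExponent, precision, sizeInBits), and for M = 2 the result matches the `semFloat8E5M2 = {15, -14, 3, 8}` values in this patch.

```cpp
#include <cassert>

// Hypothetical mirror of the four numeric fields of llvm::fltSemantics.
struct Fp8Params {
  int maxExponent;     // largest unbiased exponent of a finite value
  int minExponent;     // smallest unbiased exponent of a normal value
  unsigned precision;  // significand bits, including the implicit integer bit
  unsigned sizeInBits; // total storage width
};

// Derives IEEE-754-style parameters for an 8-bit float with M mantissa bits.
// Assumes 1 sign bit, 1 <= M <= 5, and an all-ones exponent for Inf/NaN.
constexpr Fp8Params deriveFp8Params(unsigned mantissaBits) {
  unsigned exponentBits = 8 - 1 - mantissaBits; // E = 7 - M
  int bias = (1 << (exponentBits - 1)) - 1;     // 2^(E-1) - 1
  // With the top exponent code reserved, the largest unbiased exponent equals
  // the bias, and the smallest normal exponent is 1 - bias.
  return Fp8Params{bias, 1 - bias, mantissaBits + 1, 8};
}

// M = 2 gives E5M2 and matches semFloat8E5M2 = {15, -14, 3, 8}.
static_assert(deriveFp8Params(2).maxExponent == 15, "E5M2 max exponent");
static_assert(deriveFp8Params(2).minExponent == -14, "E5M2 min exponent");
static_assert(deriveFp8Params(2).precision == 3, "E5M2 precision");
```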
The first of these (this patch) seems fairly uncontroversial. The
second is previewed here to illustrate how the scheme could extend to the
other known variant; it can be discussed in detail in the patch that
implements it.
Many conversations about these types focus on the machine-learning
ecosystem, where they are used to represent mixed-datatype computations
at a high level. At that level, which is why we also expose them in
MLIR, it is important to retain the actual type definition so that the
correct promotions, casts and rescalings can be applied when lowering to
kernels or target-specific code. We expect that most LLVM backends will
only see these types as opaque `i8` values that certain instructions
operate on.
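As an illustration of the kind of promotion or cast a lowering might perform on such opaque 8-bit storage, the sketch below relies on a property that follows from the bit layout: E5M2 has the same exponent width (5 bits) and bias (15) as IEEE binary16, so an E5M2 byte is exactly the top byte of the equivalent f16 encoding. Widening is then a shift, and narrowing is round-to-nearest-even on the dropped byte. The function names are hypothetical, this is not code from the patch, and forcing a nonzero mantissa for NaN is an assumed policy.

```cpp
#include <cstdint>

// Widening is exact: E5M2 shares binary16's exponent layout and bias, so the
// f16 encoding is the E5M2 byte followed by 8 zero mantissa bits.
constexpr uint16_t e5m2ToHalfBits(uint8_t b) { return uint16_t(b << 8); }

// Narrowing drops the low 8 mantissa bits with round-to-nearest-even.
// A carry out of the kept mantissa bumps the exponent, which also makes
// values past the largest finite E5M2 value (57344) round to infinity.
constexpr uint8_t halfBitsToE5M2(uint16_t h) {
  bool isNaN = (h & 0x7C00) == 0x7C00 && (h & 0x03FF) != 0;
  if (isNaN)
    return uint8_t((h >> 8) | 0x02); // keep a nonzero mantissa so NaN stays NaN
  uint16_t lsb = (h >> 8) & 1;       // lowest kept bit, used for ties-to-even
  return uint8_t((h + 0x7F + lsb) >> 8);
}
```

For example, the f16 value 1.125 (`0x3C80`) lies exactly halfway between the E5M2 values 1.0 and 1.25 and rounds to the even encoding `0x3C` (1.0).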
MLIR does not make it particularly easy to add new floating-point types
(the `FloatType` hierarchy is closed). Given the need to fully model
`FloatType`s and make them interoperate with tooling, such types will
always be "heavyweight", and a highly open type system is not expected
to help much. There is also a bounded number of floating-point types in
use for current and upcoming hardware, so we can simply implement them
like this (perhaps looking for some cosmetic ways to reduce the number
of places that need to change). A more generic mechanism for extending
floating-point types does not seem worth it; we should instead define
them one by one, as needed, when real hardware implements a new scheme.
Hopefully, with some additional production use and complete software
stacks, hardware makers will converge on a set of such types that is not
terribly divergent at the level that the compiler cares about.
(I cleaned up some old formatting and sorted some items for this case.
If we converge on landing this in some form, I will commit the
format-only changes separately as an NFC commit.)
Differential Revision: https://reviews.llvm.org/D133823
Added: 
    
Modified: 
    llvm/include/llvm/ADT/APFloat.h
    llvm/lib/Support/APFloat.cpp
    llvm/unittests/ADT/APFloatTest.cpp
    mlir/include/mlir-c/BuiltinTypes.h
    mlir/include/mlir/IR/Builders.h
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/include/mlir/IR/Types.h
    mlir/lib/AsmParser/TokenKinds.def
    mlir/lib/AsmParser/TypeParser.cpp
    mlir/lib/CAPI/IR/BuiltinTypes.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/Builders.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/IR/Types.cpp
    mlir/test/IR/attribute.mlir
    mlir/test/lib/Dialect/Test/TestOps.td
Removed: 
    
################################################################################
diff  --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h
index dffb259b33c73..72eef0fb4763c 100644
--- a/llvm/include/llvm/ADT/APFloat.h
+++ b/llvm/include/llvm/ADT/APFloat.h
@@ -153,10 +153,13 @@ struct APFloatBase {
     S_BFloat,
     S_IEEEsingle,
     S_IEEEdouble,
-    S_x87DoubleExtended,
     S_IEEEquad,
     S_PPCDoubleDouble,
-    S_MaxSemantics = S_PPCDoubleDouble
+    // 8-bit floating point number following IEEE-754 conventions with bit
+    // layout S1E5M2 as described in https://arxiv.org/abs/2209.05433
+    S_Float8E5M2,
+    S_x87DoubleExtended,
+    S_MaxSemantics = S_x87DoubleExtended,
   };
 
   static const llvm::fltSemantics &EnumToSemantics(Semantics S);
@@ -168,6 +171,7 @@ struct APFloatBase {
   static const fltSemantics &IEEEdouble() LLVM_READNONE;
   static const fltSemantics &IEEEquad() LLVM_READNONE;
   static const fltSemantics &PPCDoubleDouble() LLVM_READNONE;
+  static const fltSemantics &Float8E5M2() LLVM_READNONE;
   static const fltSemantics &x87DoubleExtended() LLVM_READNONE;
 
   /// A Pseudo fltsemantic used to construct APFloats that cannot conflict with
@@ -552,6 +556,7 @@ class IEEEFloat final : public APFloatBase {
   APInt convertQuadrupleAPFloatToAPInt() const;
   APInt convertF80LongDoubleAPFloatToAPInt() const;
   APInt convertPPCDoubleDoubleAPFloatToAPInt() const;
+  APInt convertFloat8E5M2APFloatToAPInt() const;
   void initFromAPInt(const fltSemantics *Sem, const APInt &api);
   void initFromHalfAPInt(const APInt &api);
   void initFromBFloatAPInt(const APInt &api);
@@ -560,6 +565,7 @@ class IEEEFloat final : public APFloatBase {
   void initFromQuadrupleAPInt(const APInt &api);
   void initFromF80LongDoubleAPInt(const APInt &api);
   void initFromPPCDoubleDoubleAPInt(const APInt &api);
+  void initFromFloat8E5M2APInt(const APInt &api);
 
   void assign(const IEEEFloat &);
   void copySignificand(const IEEEFloat &);
diff  --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index 6f888e314f13a..68063bb4d4f4c 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -80,6 +80,7 @@ namespace llvm {
   static const fltSemantics semIEEEsingle = {127, -126, 24, 32};
   static const fltSemantics semIEEEdouble = {1023, -1022, 53, 64};
   static const fltSemantics semIEEEquad = {16383, -16382, 113, 128};
+  static const fltSemantics semFloat8E5M2 = {15, -14, 3, 8};
   static const fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80};
   static const fltSemantics semBogus = {0, 0, 0, 0};
 
@@ -131,12 +132,14 @@ namespace llvm {
       return IEEEsingle();
     case S_IEEEdouble:
       return IEEEdouble();
-    case S_x87DoubleExtended:
-      return x87DoubleExtended();
     case S_IEEEquad:
       return IEEEquad();
     case S_PPCDoubleDouble:
       return PPCDoubleDouble();
+    case S_Float8E5M2:
+      return Float8E5M2();
+    case S_x87DoubleExtended:
+      return x87DoubleExtended();
     }
     llvm_unreachable("Unrecognised floating semantics");
   }
@@ -151,12 +154,14 @@ namespace llvm {
       return S_IEEEsingle;
     else if (&Sem == &llvm::APFloat::IEEEdouble())
       return S_IEEEdouble;
-    else if (&Sem == &llvm::APFloat::x87DoubleExtended())
-      return S_x87DoubleExtended;
     else if (&Sem == &llvm::APFloat::IEEEquad())
       return S_IEEEquad;
     else if (&Sem == &llvm::APFloat::PPCDoubleDouble())
       return S_PPCDoubleDouble;
+    else if (&Sem == &llvm::APFloat::Float8E5M2())
+      return S_Float8E5M2;
+    else if (&Sem == &llvm::APFloat::x87DoubleExtended())
+      return S_x87DoubleExtended;
     else
       llvm_unreachable("Unknown floating semantics");
   }
@@ -173,18 +178,15 @@ namespace llvm {
   const fltSemantics &APFloatBase::IEEEdouble() {
     return semIEEEdouble;
   }
-  const fltSemantics &APFloatBase::IEEEquad() {
-    return semIEEEquad;
+  const fltSemantics &APFloatBase::IEEEquad() { return semIEEEquad; }
+  const fltSemantics &APFloatBase::PPCDoubleDouble() {
+    return semPPCDoubleDouble;
   }
+  const fltSemantics &APFloatBase::Float8E5M2() { return semFloat8E5M2; }
   const fltSemantics &APFloatBase::x87DoubleExtended() {
     return semX87DoubleExtended;
   }
-  const fltSemantics &APFloatBase::Bogus() {
-    return semBogus;
-  }
-  const fltSemantics &APFloatBase::PPCDoubleDouble() {
-    return semPPCDoubleDouble;
-  }
+  const fltSemantics &APFloatBase::Bogus() { return semBogus; }
 
   constexpr RoundingMode APFloatBase::rmNearestTiesToEven;
   constexpr RoundingMode APFloatBase::rmTowardPositive;
@@ -3353,6 +3355,33 @@ APInt IEEEFloat::convertHalfAPFloatToAPInt() const {
                     (mysignificand & 0x3ff)));
 }
 
+APInt IEEEFloat::convertFloat8E5M2APFloatToAPInt() const {
+  assert(semantics == (const llvm::fltSemantics *)&semFloat8E5M2);
+  assert(partCount() == 1);
+
+  uint32_t myexponent, mysignificand;
+
+  if (isFiniteNonZero()) {
+    myexponent = exponent + 15; // bias
+    mysignificand = (uint32_t)*significandParts();
+    if (myexponent == 1 && !(mysignificand & 0x4))
+      myexponent = 0; // denormal
+  } else if (category == fcZero) {
+    myexponent = 0;
+    mysignificand = 0;
+  } else if (category == fcInfinity) {
+    myexponent = 0x1f;
+    mysignificand = 0;
+  } else {
+    assert(category == fcNaN && "Unknown category!");
+    myexponent = 0x1f;
+    mysignificand = (uint32_t)*significandParts();
+  }
+
+  return APInt(8, (((sign & 1) << 7) | ((myexponent & 0x1f) << 2) |
+                   (mysignificand & 0x3)));
+}
+
 // This function creates an APInt that is just a bit map of the floating
 // point constant as it would appear in memory.  It is not a conversion,
 // and treating the result as a normal integer is unlikely to be useful.
@@ -3376,6 +3405,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
   if (semantics == (const llvm::fltSemantics *)&semPPCDoubleDoubleLegacy)
     return convertPPCDoubleDoubleAPFloatToAPInt();
 
+  if (semantics == (const llvm::fltSemantics *)&semFloat8E5M2)
+    return convertFloat8E5M2APFloatToAPInt();
+
   assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended &&
          "unknown format!");
   return convertF80LongDoubleAPFloatToAPInt();
@@ -3603,6 +3635,34 @@ void IEEEFloat::initFromHalfAPInt(const APInt &api) {
   }
 }
 
+void IEEEFloat::initFromFloat8E5M2APInt(const APInt &api) {
+  uint32_t i = (uint32_t)*api.getRawData();
+  uint32_t myexponent = (i >> 2) & 0x1f;
+  uint32_t mysignificand = i & 0x3;
+
+  initialize(&semFloat8E5M2);
+  assert(partCount() == 1);
+
+  sign = i >> 7;
+  if (myexponent == 0 && mysignificand == 0) {
+    makeZero(sign);
+  } else if (myexponent == 0x1f && mysignificand == 0) {
+    makeInf(sign);
+  } else if (myexponent == 0x1f && mysignificand != 0) {
+    category = fcNaN;
+    exponent = exponentNaN();
+    *significandParts() = mysignificand;
+  } else {
+    category = fcNormal;
+    exponent = myexponent - 15; // bias
+    *significandParts() = mysignificand;
+    if (myexponent == 0) // denormal
+      exponent = -14;
+    else
+      *significandParts() |= 0x4; // integer bit
+  }
+}
+
 /// Treat api as containing the bits of a floating point number.  Currently
 /// we infer the floating point type from the size of the APInt.  The
 /// isIEEE argument distinguishes between PPC128 and IEEE128 (not meaningful
@@ -3623,6 +3683,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
     return initFromQuadrupleAPInt(api);
   if (Sem == &semPPCDoubleDoubleLegacy)
     return initFromPPCDoubleDoubleAPInt(api);
+  if (Sem == &semFloat8E5M2)
+    return initFromFloat8E5M2APInt(api);
 
   llvm_unreachable(nullptr);
 }
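The two new routines above, `convertFloat8E5M2APFloatToAPInt` and `initFromFloat8E5M2APInt`, pack and unpack APFloat's (sign, category, exponent, significand) state. The standalone sketch below mirrors their logic on plain integers so it can be checked exhaustively over all 256 bit patterns. `Parts`, `Category`, `encodeE5M2` and `decodeE5M2` are hypothetical stand-ins, not LLVM types. The detail it reproduces is the denormal convention: a denormal is stored as a normal-category value with exponent -14 and no integer bit (0x4), and encoding maps that back to a stored exponent field of 0.

```cpp
#include <cstdint>

// Hypothetical stand-in for APFloat's category (fcZero, fcNormal, ...).
enum class Category { Zero, Normal, Infinity, NaN };

struct Parts {
  bool sign = false;
  Category category = Category::Zero;
  int exponent = 0;         // unbiased exponent (meaningful for Normal only)
  uint32_t significand = 0; // 3-bit significand; 0x4 is the integer bit
};

constexpr int kBias = 15;

// Mirrors initFromFloat8E5M2APInt: split the byte into its fields.
Parts decodeE5M2(uint8_t bits) {
  uint32_t expField = (bits >> 2) & 0x1f;
  uint32_t mantField = bits & 0x3;
  Parts p;
  p.sign = (bits >> 7) != 0;
  if (expField == 0 && mantField == 0) {
    p.category = Category::Zero;
  } else if (expField == 0x1f) {
    p.category = mantField == 0 ? Category::Infinity : Category::NaN;
    p.significand = mantField; // NaN payload; 0 for infinity
  } else {
    p.category = Category::Normal; // denormals are also "normal" here
    p.significand = mantField;
    if (expField == 0) {
      p.exponent = 1 - kBias; // denormal: exponent -14, no integer bit
    } else {
      p.exponent = int(expField) - kBias;
      p.significand |= 0x4; // make the implicit integer bit explicit
    }
  }
  return p;
}

// Mirrors convertFloat8E5M2APFloatToAPInt: pack the fields into a byte.
uint8_t encodeE5M2(const Parts &p) {
  uint32_t expField = 0, mantField = 0;
  switch (p.category) {
  case Category::Normal:
    expField = uint32_t(p.exponent + kBias);
    mantField = p.significand;
    if (expField == 1 && !(mantField & 0x4))
      expField = 0; // no integer bit at the minimum exponent: denormal
    break;
  case Category::Zero:
    break;
  case Category::Infinity:
    expField = 0x1f;
    break;
  case Category::NaN:
    expField = 0x1f;
    mantField = p.significand;
    break;
  }
  return uint8_t(((p.sign ? 1u : 0u) << 7) | ((expField & 0x1f) << 2) |
                 (mantField & 0x3));
}
```

Because the sketch keeps the NaN payload in the significand, every one of the 256 encodings round-trips through `decodeE5M2` and `encodeE5M2` unchanged.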
diff  --git a/llvm/unittests/ADT/APFloatTest.cpp b/llvm/unittests/ADT/APFloatTest.cpp
index 3caa09f364ff4..90e39bf07b930 100644
--- a/llvm/unittests/ADT/APFloatTest.cpp
+++ b/llvm/unittests/ADT/APFloatTest.cpp
@@ -1752,18 +1752,20 @@ TEST(APFloatTest, getZero) {
     const unsigned long long bitPattern[2];
     const unsigned bitPatternLength;
   } const GetZeroTest[] = {
-    { &APFloat::IEEEhalf(), false, {0, 0}, 1},
-    { &APFloat::IEEEhalf(), true, {0x8000ULL, 0}, 1},
-    { &APFloat::IEEEsingle(), false, {0, 0}, 1},
-    { &APFloat::IEEEsingle(), true, {0x80000000ULL, 0}, 1},
-    { &APFloat::IEEEdouble(), false, {0, 0}, 1},
-    { &APFloat::IEEEdouble(), true, {0x8000000000000000ULL, 0}, 1},
-    { &APFloat::IEEEquad(), false, {0, 0}, 2},
-    { &APFloat::IEEEquad(), true, {0, 0x8000000000000000ULL}, 2},
-    { &APFloat::PPCDoubleDouble(), false, {0, 0}, 2},
-    { &APFloat::PPCDoubleDouble(), true, {0x8000000000000000ULL, 0}, 2},
-    { &APFloat::x87DoubleExtended(), false, {0, 0}, 2},
-    { &APFloat::x87DoubleExtended(), true, {0, 0x8000ULL}, 2},
+      {&APFloat::IEEEhalf(), false, {0, 0}, 1},
+      {&APFloat::IEEEhalf(), true, {0x8000ULL, 0}, 1},
+      {&APFloat::IEEEsingle(), false, {0, 0}, 1},
+      {&APFloat::IEEEsingle(), true, {0x80000000ULL, 0}, 1},
+      {&APFloat::IEEEdouble(), false, {0, 0}, 1},
+      {&APFloat::IEEEdouble(), true, {0x8000000000000000ULL, 0}, 1},
+      {&APFloat::IEEEquad(), false, {0, 0}, 2},
+      {&APFloat::IEEEquad(), true, {0, 0x8000000000000000ULL}, 2},
+      {&APFloat::PPCDoubleDouble(), false, {0, 0}, 2},
+      {&APFloat::PPCDoubleDouble(), true, {0x8000000000000000ULL, 0}, 2},
+      {&APFloat::x87DoubleExtended(), false, {0, 0}, 2},
+      {&APFloat::x87DoubleExtended(), true, {0, 0x8000ULL}, 2},
+      {&APFloat::Float8E5M2(), false, {0, 0}, 1},
+      {&APFloat::Float8E5M2(), true, {0x80ULL, 0}, 1},
   };
-  const unsigned NumGetZeroTests = 12;
+  const unsigned NumGetZeroTests = 14;
   for (unsigned i = 0; i < NumGetZeroTests; ++i) {
@@ -4754,7 +4756,7 @@ TEST(APFloatTest, x87Next) {
   EXPECT_TRUE(ilogb(F) == -1);
 }
 
-TEST(APFloatTest, ToDouble) {
+TEST(APFloatTest, IEEEdoubleToDouble) {
   APFloat DPosZero(0.0);
   APFloat DPosZeroToDouble(DPosZero.convertToDouble());
   EXPECT_TRUE(DPosZeroToDouble.isPosZero());
@@ -4790,7 +4792,9 @@ TEST(APFloatTest, ToDouble) {
             DNegInf.convertToDouble());
   APFloat DQNaN = APFloat::getQNaN(APFloat::IEEEdouble());
   EXPECT_TRUE(std::isnan(DQNaN.convertToDouble()));
+}
 
+TEST(APFloatTest, IEEEsingleToDouble) {
   APFloat FPosZero(0.0F);
   APFloat FPosZeroToDouble(FPosZero.convertToDouble());
   EXPECT_TRUE(FPosZeroToDouble.isPosZero());
@@ -4825,7 +4829,9 @@ TEST(APFloatTest, ToDouble) {
             FNegInf.convertToDouble());
   APFloat FQNaN = APFloat::getQNaN(APFloat::IEEEsingle());
   EXPECT_TRUE(std::isnan(FQNaN.convertToDouble()));
+}
 
+TEST(APFloatTest, IEEEhalfToDouble) {
   APFloat HPosZero = APFloat::getZero(APFloat::IEEEhalf());
   APFloat HPosZeroToDouble(HPosZero.convertToDouble());
   EXPECT_TRUE(HPosZeroToDouble.isPosZero());
@@ -4867,7 +4873,9 @@ TEST(APFloatTest, ToDouble) {
   APFloat BNegZero = APFloat::getZero(APFloat::IEEEhalf(), true);
   APFloat BNegZeroToDouble(BNegZero.convertToDouble());
   EXPECT_TRUE(BNegZeroToDouble.isNegZero());
+}
 
+TEST(APFloatTest, BFloatToDouble) {
   APFloat BOne(APFloat::BFloat(), "1.0");
   EXPECT_EQ(1.0, BOne.convertToDouble());
   APFloat BPosLargest = APFloat::getLargest(APFloat::BFloat(), false);
@@ -4901,7 +4909,35 @@ TEST(APFloatTest, ToDouble) {
   EXPECT_TRUE(std::isnan(BQNaN.convertToDouble()));
 }
 
-TEST(APFloatTest, ToFloat) {
+TEST(APFloatTest, Float8E5M2ToDouble) {
+  APFloat One(APFloat::Float8E5M2(), "1.0");
+  EXPECT_EQ(1.0, One.convertToDouble());
+  APFloat Two(APFloat::Float8E5M2(), "2.0");
+  EXPECT_EQ(2.0, Two.convertToDouble());
+  APFloat PosLargest = APFloat::getLargest(APFloat::Float8E5M2(), false);
+  EXPECT_EQ(5.734400e+04, PosLargest.convertToDouble());
+  APFloat NegLargest = APFloat::getLargest(APFloat::Float8E5M2(), true);
+  EXPECT_EQ(-5.734400e+04, NegLargest.convertToDouble());
+  APFloat PosSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float8E5M2(), false);
+  EXPECT_EQ(0x1.p-14, PosSmallest.convertToDouble());
+  APFloat NegSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float8E5M2(), true);
+  EXPECT_EQ(-0x1.p-14, NegSmallest.convertToDouble());
+
+  APFloat SmallestDenorm = APFloat::getSmallest(APFloat::Float8E5M2(), false);
+  EXPECT_TRUE(SmallestDenorm.isDenormal());
+  EXPECT_EQ(0x1p-16, SmallestDenorm.convertToDouble());
+
+  APFloat PosInf = APFloat::getInf(APFloat::Float8E5M2());
+  EXPECT_EQ(std::numeric_limits<double>::infinity(), PosInf.convertToDouble());
+  APFloat NegInf = APFloat::getInf(APFloat::Float8E5M2(), true);
+  EXPECT_EQ(-std::numeric_limits<double>::infinity(), NegInf.convertToDouble());
+  APFloat QNaN = APFloat::getQNaN(APFloat::Float8E5M2());
+  EXPECT_TRUE(std::isnan(QNaN.convertToDouble()));
+}
+
+TEST(APFloatTest, IEEEsingleToFloat) {
   APFloat FPosZero(0.0F);
   APFloat FPosZeroToFloat(FPosZero.convertToFloat());
   EXPECT_TRUE(FPosZeroToFloat.isPosZero());
@@ -4935,7 +4971,9 @@ TEST(APFloatTest, ToFloat) {
   EXPECT_EQ(-std::numeric_limits<float>::infinity(), FNegInf.convertToFloat());
   APFloat FQNaN = APFloat::getQNaN(APFloat::IEEEsingle());
   EXPECT_TRUE(std::isnan(FQNaN.convertToFloat()));
+}
 
+TEST(APFloatTest, IEEEhalfToFloat) {
   APFloat HPosZero = APFloat::getZero(APFloat::IEEEhalf());
   APFloat HPosZeroToFloat(HPosZero.convertToFloat());
   EXPECT_TRUE(HPosZeroToFloat.isPosZero());
@@ -4969,7 +5007,9 @@ TEST(APFloatTest, ToFloat) {
   EXPECT_EQ(-std::numeric_limits<float>::infinity(), HNegInf.convertToFloat());
   APFloat HQNaN = APFloat::getQNaN(APFloat::IEEEhalf());
   EXPECT_TRUE(std::isnan(HQNaN.convertToFloat()));
+}
 
+TEST(APFloatTest, BFloatToFloat) {
   APFloat BPosZero = APFloat::getZero(APFloat::BFloat());
   APFloat BPosZeroToDouble(BPosZero.convertToFloat());
   EXPECT_TRUE(BPosZeroToDouble.isPosZero());
@@ -5008,4 +5048,41 @@ TEST(APFloatTest, ToFloat) {
   APFloat BQNaN = APFloat::getQNaN(APFloat::BFloat());
   EXPECT_TRUE(std::isnan(BQNaN.convertToFloat()));
 }
+
+TEST(APFloatTest, Float8E5M2ToFloat) {
+  APFloat PosZero = APFloat::getZero(APFloat::Float8E5M2());
+  APFloat PosZeroToFloat(PosZero.convertToFloat());
+  EXPECT_TRUE(PosZeroToFloat.isPosZero());
+  APFloat NegZero = APFloat::getZero(APFloat::Float8E5M2(), true);
+  APFloat NegZeroToFloat(NegZero.convertToFloat());
+  EXPECT_TRUE(NegZeroToFloat.isNegZero());
+
+  APFloat One(APFloat::Float8E5M2(), "1.0");
+  EXPECT_EQ(1.0F, One.convertToFloat());
+  APFloat Two(APFloat::Float8E5M2(), "2.0");
+  EXPECT_EQ(2.0F, Two.convertToFloat());
+
+  APFloat PosLargest = APFloat::getLargest(APFloat::Float8E5M2(), false);
+  EXPECT_EQ(5.734400e+04, PosLargest.convertToFloat());
+  APFloat NegLargest = APFloat::getLargest(APFloat::Float8E5M2(), true);
+  EXPECT_EQ(-5.734400e+04, NegLargest.convertToFloat());
+  APFloat PosSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float8E5M2(), false);
+  EXPECT_EQ(0x1.p-14, PosSmallest.convertToFloat());
+  APFloat NegSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float8E5M2(), true);
+  EXPECT_EQ(-0x1.p-14, NegSmallest.convertToFloat());
+
+  APFloat SmallestDenorm = APFloat::getSmallest(APFloat::Float8E5M2(), false);
+  EXPECT_TRUE(SmallestDenorm.isDenormal());
+  EXPECT_EQ(0x1.p-16, SmallestDenorm.convertToFloat());
+
+  APFloat PosInf = APFloat::getInf(APFloat::Float8E5M2());
+  EXPECT_EQ(std::numeric_limits<float>::infinity(), PosInf.convertToFloat());
+  APFloat NegInf = APFloat::getInf(APFloat::Float8E5M2(), true);
+  EXPECT_EQ(-std::numeric_limits<float>::infinity(), NegInf.convertToFloat());
+  APFloat QNaN = APFloat::getQNaN(APFloat::Float8E5M2());
+  EXPECT_TRUE(std::isnan(QNaN.convertToFloat()));
 }
+
+} // namespace
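The constants asserted by the new `Float8E5M2ToDouble` and `Float8E5M2ToFloat` tests can be derived by hand from the encoding, writing e for the stored 5-bit exponent field, m for the stored 2-bit mantissa field, and using bias 15 (e = 31 is reserved for infinities and NaNs):

```latex
\begin{aligned}
x_{\text{normal}}   &= (-1)^s \, 2^{e-15} \left(1 + \tfrac{m}{4}\right), \qquad 1 \le e \le 30,\\
x_{\text{denormal}} &= (-1)^s \, 2^{-14} \cdot \tfrac{m}{4}, \qquad e = 0,\ m \ne 0,\\
\text{largest}            &= 2^{30-15}\left(1 + \tfrac{3}{4}\right) = 1.75 \cdot 2^{15} = 57344,\\
\text{smallest normal}    &= 2^{1-15} = 2^{-14},\\
\text{smallest denormal}  &= 2^{-14} \cdot \tfrac{1}{4} = 2^{-16}.
\end{aligned}
```

These are the `5.734400e+04`, `0x1.p-14` and `0x1p-16` values expected from `getLargest`, `getSmallestNormalized` and `getSmallest`.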
diff  --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index d1083f9323bf0..9bd3d510b2483 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -67,6 +67,13 @@ MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx);
 // Floating-point types.
 //===----------------------------------------------------------------------===//
 
+/// Checks whether the given type is an f8E5M2 type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type);
+
+/// Creates an f8E5M2 type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx);
+
 /// Checks whether the given type is a bf16 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
 
diff  --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index a10503c72735a..870c834ce2b0d 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -59,6 +59,7 @@ class Builder {
                        Attribute metadata = Attribute());
 
   // Types.
+  FloatType getFloat8E5M2Type();
   FloatType getBF16Type();
   FloatType getF16Type();
   FloatType getF32Type();
diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index cb282516438c9..1925127251558 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -46,6 +46,7 @@ class FloatType : public Type {
   static FloatType getF64(MLIRContext *ctx);
   static FloatType getF80(MLIRContext *ctx);
   static FloatType getF128(MLIRContext *ctx);
+  static FloatType getFloat8E5M2(MLIRContext *ctx);
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(Type type);
@@ -373,8 +374,12 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
 }
 
 inline bool FloatType::classof(Type type) {
-  return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
-                  Float80Type, Float128Type>();
+  return type.isa<Float8E5M2Type, BFloat16Type, Float16Type, Float32Type,
+                  Float64Type, Float80Type, Float128Type>();
+}
+
+inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
+  return Float8E5M2Type::get(ctx);
 }
 
 inline FloatType FloatType::getBF16(MLIRContext *ctx) {
diff  --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index aaadc1f6b9a75..50d8b3a0cb44a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -76,6 +76,28 @@ class Builtin_FloatType<string name>
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Float8E5M2Type
+
+def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2"> {
+  let summary = "8-bit floating point with 2 bit mantissa";
+  let description = [{
+    An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
+    mantissa. This is not a standard type as defined by IEEE-754, but it
+    follows similar conventions with the following characteristics:
+
+      * bit encoding: S1E5M2
+      * exponent bias: 15
+      * infinities: supported with exponent set to all 1s and mantissa 0s
+      * NaNs: supported with exponent bits set to all 1s and mantissa of
+        (01, 10, or 11)
+      * denormals when exponent is 0
+
+    Described in: https://arxiv.org/abs/2209.05433
+  }];
+}
+
+
 //===----------------------------------------------------------------------===//
 // BFloat16Type
 
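The encoding described in the `Builtin_Float8E5M2` documentation above can be decoded directly to a `double`. The sketch below (hypothetical `decodeE5M2ToDouble`, standard library only) follows those bullets: bias 15, infinities with an all-ones exponent and mantissa 00, NaNs with an all-ones exponent and mantissa 01, 10 or 11, and denormals when the exponent field is 0. It drops the sign of NaNs, which is a simplification of this sketch.

```cpp
#include <cmath>
#include <cstdint>
#include <limits>

// Decodes an S1E5M2 byte to the double with the same value.
double decodeE5M2ToDouble(uint8_t bits) {
  double sign = (bits & 0x80) ? -1.0 : 1.0;
  int expField = (bits >> 2) & 0x1f; // 5 exponent bits
  int mantField = bits & 0x3;        // 2 mantissa bits
  if (expField == 0x1f) {            // all-ones exponent: infinity or NaN
    if (mantField == 0)
      return sign * std::numeric_limits<double>::infinity();
    return std::numeric_limits<double>::quiet_NaN(); // sign not preserved
  }
  if (expField == 0) // zero or denormal: 2^-14 * (m / 4), no implicit 1
    return sign * std::ldexp(mantField / 4.0, -14);
  // Normal: implicit leading 1, exponent bias 15.
  return sign * std::ldexp(1.0 + mantField / 4.0, expField - 15);
}
```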
diff  --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 7f657d8f81e5e..18660cb9e6f5e 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -123,6 +123,7 @@ class Type {
   // Convenience predicates.  This is only for floating point types,
   // derived types should use isa/dyn_cast.
   bool isIndex() const;
+  bool isFloat8E5M2() const;
   bool isBF16() const;
   bool isF16() const;
   bool isF32() const;
diff  --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 94d2fd3687fc3..02eba88f78b0d 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -93,6 +93,7 @@ TOK_KEYWORD(f16)
 TOK_KEYWORD(f32)
 TOK_KEYWORD(f64)
 TOK_KEYWORD(f80)
+TOK_KEYWORD(f8E5M2)
 TOK_KEYWORD(f128)
 TOK_KEYWORD(false)
 TOK_KEYWORD(floordiv)
diff  --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 16da006809d29..5ab7a89eac01e 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -30,6 +30,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
   case Token::kw_tuple:
   case Token::kw_vector:
   case Token::inttype:
+  case Token::kw_f8E5M2:
   case Token::kw_bf16:
   case Token::kw_f16:
   case Token::kw_f32:
@@ -286,6 +287,9 @@ Type Parser::parseNonFunctionType() {
   }
 
   // float-type
+  case Token::kw_f8E5M2:
+    consumeToken(Token::kw_f8E5M2);
+    return builder.getFloat8E5M2Type();
   case Token::kw_bf16:
     consumeToken(Token::kw_bf16);
     return builder.getBF16Type();
diff  --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index be44b76e8c615..ad9a5bc6640e2 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -68,6 +68,14 @@ MlirType mlirIndexTypeGet(MlirContext ctx) {
 // Floating-point types.
 //===----------------------------------------------------------------------===//
 
+bool mlirTypeIsAFloat8E5M2(MlirType type) {
+  return unwrap(type).isFloat8E5M2();
+}
+
+MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
+  return wrap(FloatType::getFloat8E5M2(unwrap(ctx)));
+}
+
 bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
 
 MlirType mlirBF16TypeGet(MlirContext ctx) {
diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index aaebef4341889..f51ea60de523f 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2179,6 +2179,7 @@ void AsmPrinter::Impl::printType(Type type) {
                            opaqueTy.getTypeData());
       })
       .Case<IndexType>([&](Type) { os << "index"; })
+      .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
       .Case<Float16Type>([&](Type) { os << "f16"; })
       .Case<Float32Type>([&](Type) { os << "f32"; })
diff  --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 8bd3c72d185f1..053ffce1b1579 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -33,6 +33,10 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
 // Types.
 //===----------------------------------------------------------------------===//
 
+FloatType Builder::getFloat8E5M2Type() {
+  return FloatType::getFloat8E5M2(context);
+}
+
 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
 
 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index b1c2e3e92eea2..013686719e8e8 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -88,6 +88,8 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
 //===----------------------------------------------------------------------===//
 
 unsigned FloatType::getWidth() {
+  if (isa<Float8E5M2Type>())
+    return 8;
   if (isa<Float16Type, BFloat16Type>())
     return 16;
   if (isa<Float32Type>())
@@ -103,6 +105,8 @@ unsigned FloatType::getWidth() {
 
 /// Returns the floating semantics for the given type.
 const llvm::fltSemantics &FloatType::getFloatSemantics() {
+  if (isa<Float8E5M2Type>())
+    return APFloat::Float8E5M2();
   if (isa<BFloat16Type>())
     return APFloat::BFloat();
   if (isa<Float16Type>())
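Adding `f8E5M2` required a parallel branch in several places: the parser keyword, the printer case, `getWidth`, `getFloatSemantics` and the context cache. The commit message mentions looking for cosmetic ways to reduce that count. One possible shape, sketched below with entirely hypothetical names (`FloatKind`, `FloatKindInfo`, `kFloatKinds`), is a single descriptor table from which keyword parsing, printing and width lookup are all derived. This only illustrates the idea; it is not how MLIR is structured in this patch.

```cpp
#include <optional>
#include <string_view>

// Hypothetical closed set of builtin float kinds (mirrors the MLIR list).
enum class FloatKind { F8E5M2, BF16, F16, F32, F64, F80, F128 };

struct FloatKindInfo {
  FloatKind kind;
  std::string_view keyword; // spelling in the textual IR
  unsigned width;           // storage width in bits
};

// One row per kind: adding a type means adding one row here.
constexpr FloatKindInfo kFloatKinds[] = {
    {FloatKind::F8E5M2, "f8E5M2", 8}, {FloatKind::BF16, "bf16", 16},
    {FloatKind::F16, "f16", 16},      {FloatKind::F32, "f32", 32},
    {FloatKind::F64, "f64", 64},      {FloatKind::F80, "f80", 80},
    {FloatKind::F128, "f128", 128},
};

// Parser side: keyword -> kind.
std::optional<FloatKind> parseFloatKeyword(std::string_view keyword) {
  for (const FloatKindInfo &info : kFloatKinds)
    if (info.keyword == keyword)
      return info.kind;
  return std::nullopt;
}

// Printer and getWidth side: kind -> descriptor row.
const FloatKindInfo &lookup(FloatKind kind) {
  for (const FloatKindInfo &info : kFloatKinds)
    if (info.kind == kind)
      return info;
  return kFloatKinds[0]; // not reached while every kind has a row
}
```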
diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 3d41823eb6c16..7ddcc2ff11d92 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -206,6 +206,7 @@ class MLIRContextImpl {
   StorageUniquer typeUniquer;
 
   /// Cached Type Instances.
+  Float8E5M2Type f8E5M2Ty;
   BFloat16Type bf16Ty;
   Float16Type f16Ty;
   Float32Type f32Ty;
@@ -276,6 +277,7 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
 
   //// Types.
   /// Floating-point Types.
+  impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
   impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
   impl->f16Ty = TypeUniquer::get<Float16Type>(this);
   impl->f32Ty = TypeUniquer::get<Float32Type>(this);
@@ -840,6 +842,9 @@ AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) {
 /// This should not be used directly.
 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
 
+Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) {
+  return context->getImpl().f8E5M2Ty;
+}
 BFloat16Type BFloat16Type::get(MLIRContext *context) {
   return context->getImpl().bf16Ty;
 }
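`Float8E5M2Type::get` above does not unique anything on each call. The context creates one instance up front (`impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this)`), and `get` returns that cached value, so comparing two such types is an identity comparison. A minimal standalone sketch of the pattern follows; `Context`, `TypeStorage`, `Type` and `getFloat8E5M2Type` are hypothetical stand-ins for `MLIRContextImpl`, the uniqued storage and the type handle.

```cpp
#include <memory>

// Hypothetical storage object; in MLIR the storage uniquer owns the real one.
struct TypeStorage {
  unsigned width;
};

// Hypothetical context that eagerly creates one storage per builtin type.
class Context {
public:
  Context() : f8E5M2Storage(std::make_unique<TypeStorage>(TypeStorage{8})) {}
  const TypeStorage *getF8E5M2() const { return f8E5M2Storage.get(); }

private:
  std::unique_ptr<TypeStorage> f8E5M2Storage;
};

// A type is a thin handle around a storage pointer; equality is identity.
struct Type {
  const TypeStorage *impl;
  bool operator==(const Type &other) const { return impl == other.impl; }
};

// Returns the cached instance instead of creating a new one.
Type getFloat8E5M2Type(const Context &ctx) { return Type{ctx.getF8E5M2()}; }
```

A consequence of this design is that the same type obtained from two different contexts never compares equal.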
diff  --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index defe2dacfac29..b97388bf33f52 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -18,6 +18,7 @@ using namespace mlir::detail;
 
 MLIRContext *Type::getContext() const { return getDialect().getContext(); }
 
+bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); }
 bool Type::isBF16() const { return isa<BFloat16Type>(); }
 bool Type::isF16() const { return isa<Float16Type>(); }
 bool Type::isF32() const { return isa<Float32Type>(); }
diff  --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index f27e53e334978..1bdbdc25bdcac 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -31,6 +31,42 @@ func.func @any_attr_of_fail() {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// Test float attributes
+//===----------------------------------------------------------------------===//
+
+func.func @float_attrs_pass() {
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f8E5M2
+    float_attr = 2. : f8E5M2
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f16
+    float_attr = 2. : f16
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : bf16
+    float_attr = 2. : bf16
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f32
+    float_attr = 2. : f32
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f64
+    float_attr = 2. : f64
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f80
+    float_attr = 2. : f80
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f128
+    float_attr = 2. : f128
+  } : () -> ()
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // Test integer attributes
 //===----------------------------------------------------------------------===//
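The new attribute test uses the value `2.`, which is exactly representable in E5M2, so the CHECK line expects it back as `2.000000e+00`. With only two stored mantissa bits, most decimal literals are not exact. The hypothetical helper below (standard library only) checks whether a `double` is exactly representable, assuming the IEEE-style E5M2 range: normals with unbiased exponent in [-14, 15] and 2 fraction bits, denormals as multiples of 2^-16, plus signed zeros and infinities.

```cpp
#include <cmath>

// Returns true if x is exactly representable in E5M2 (S1E5M2, bias 15).
bool fitsInE5M2(double x) {
  if (std::isnan(x))
    return false; // NaN has no value to match exactly
  if (x == 0.0 || std::isinf(x))
    return true;
  int e = 0;
  double f = std::frexp(std::fabs(x), &e); // |x| = f * 2^e, f in [0.5, 1)
  int unbiased = e - 1;                    // |x| = (2f) * 2^(e-1), 2f in [1, 2)
  if (unbiased > 15)
    return false; // above the largest finite binade
  if (unbiased >= -14) {
    double scaled = f * 8.0; // (2f) * 4: integral iff 2 fraction bits suffice
    return scaled == std::floor(scaled);
  }
  double units = std::ldexp(std::fabs(x), 16); // denormal step is 2^-16
  return units == std::floor(units) && units < 4.0;
}
```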
diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 0a9c6beb39786..3d6acb89ca68d 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -193,6 +193,14 @@ def TypeStringAttrWithTypeOp : TEST_Op<"string_attr_with_type"> {
   let assemblyFormat = "$attr attr-dict";
 }
 
+def FloatAttrOp : TEST_Op<"float_attrs"> {
+  // TODO: Clean up the OpBase float type and attribute selectors so they
+  // can express all of the types.
+  let arguments = (ins
+    AnyAttr:$float_attr
+  );
+}
+
 def I32Case5:  I32EnumAttrCase<"case5", 5>;
 def I32Case10: I32EnumAttrCase<"case10", 10>;
 
        
    
    