[llvm] e68c7a9 - Revert "Add APFloat and MLIR type support for fp8 (e5m2)."

Vitaly Buka via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 2 21:22:57 PDT 2022


Author: Vitaly Buka
Date: 2022-10-02T21:22:44-07:00
New Revision: e68c7a99176d89c861837bca48ab385e900fd0bc

URL: https://github.com/llvm/llvm-project/commit/e68c7a99176d89c861837bca48ab385e900fd0bc
DIFF: https://github.com/llvm/llvm-project/commit/e68c7a99176d89c861837bca48ab385e900fd0bc.diff

LOG: Revert "Add APFloat and MLIR type support for fp8 (e5m2)."

The reverted change breaks the buildbots: https://lab.llvm.org/buildbot/#/builders/37/builds/17086

This reverts commit 2dc68b5398258c7a0cf91f10192d058e787afcdf.

Added: 
    

Modified: 
    llvm/include/llvm/ADT/APFloat.h
    llvm/lib/Support/APFloat.cpp
    llvm/unittests/ADT/APFloatTest.cpp
    mlir/include/mlir-c/BuiltinTypes.h
    mlir/include/mlir/IR/Builders.h
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/include/mlir/IR/Types.h
    mlir/lib/AsmParser/TokenKinds.def
    mlir/lib/AsmParser/TypeParser.cpp
    mlir/lib/CAPI/IR/BuiltinTypes.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/Builders.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/IR/Types.cpp
    mlir/test/IR/attribute.mlir
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h
index 72eef0fb4763c..dffb259b33c73 100644
--- a/llvm/include/llvm/ADT/APFloat.h
+++ b/llvm/include/llvm/ADT/APFloat.h
@@ -153,13 +153,10 @@ struct APFloatBase {
     S_BFloat,
     S_IEEEsingle,
     S_IEEEdouble,
+    S_x87DoubleExtended,
     S_IEEEquad,
     S_PPCDoubleDouble,
-    // 8-bit floating point number following IEEE-754 conventions with bit
-    // layout S1E5M2 as described in https://arxiv.org/abs/2209.05433
-    S_Float8E5M2,
-    S_x87DoubleExtended,
-    S_MaxSemantics = S_x87DoubleExtended,
+    S_MaxSemantics = S_PPCDoubleDouble
   };
 
   static const llvm::fltSemantics &EnumToSemantics(Semantics S);
@@ -171,7 +168,6 @@ struct APFloatBase {
   static const fltSemantics &IEEEdouble() LLVM_READNONE;
   static const fltSemantics &IEEEquad() LLVM_READNONE;
   static const fltSemantics &PPCDoubleDouble() LLVM_READNONE;
-  static const fltSemantics &Float8E5M2() LLVM_READNONE;
   static const fltSemantics &x87DoubleExtended() LLVM_READNONE;
 
   /// A Pseudo fltsemantic used to construct APFloats that cannot conflict with
@@ -556,7 +552,6 @@ class IEEEFloat final : public APFloatBase {
   APInt convertQuadrupleAPFloatToAPInt() const;
   APInt convertF80LongDoubleAPFloatToAPInt() const;
   APInt convertPPCDoubleDoubleAPFloatToAPInt() const;
-  APInt convertFloat8E5M2APFloatToAPInt() const;
   void initFromAPInt(const fltSemantics *Sem, const APInt &api);
   void initFromHalfAPInt(const APInt &api);
   void initFromBFloatAPInt(const APInt &api);
@@ -565,7 +560,6 @@ class IEEEFloat final : public APFloatBase {
   void initFromQuadrupleAPInt(const APInt &api);
   void initFromF80LongDoubleAPInt(const APInt &api);
   void initFromPPCDoubleDoubleAPInt(const APInt &api);
-  void initFromFloat8E5M2APInt(const APInt &api);
 
   void assign(const IEEEFloat &);
   void copySignificand(const IEEEFloat &);

diff  --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index 68063bb4d4f4c..6f888e314f13a 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -80,7 +80,6 @@ namespace llvm {
   static const fltSemantics semIEEEsingle = {127, -126, 24, 32};
   static const fltSemantics semIEEEdouble = {1023, -1022, 53, 64};
   static const fltSemantics semIEEEquad = {16383, -16382, 113, 128};
-  static const fltSemantics semFloat8E5M2 = {15, -14, 3, 8};
   static const fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80};
   static const fltSemantics semBogus = {0, 0, 0, 0};
 
@@ -132,14 +131,12 @@ namespace llvm {
       return IEEEsingle();
     case S_IEEEdouble:
       return IEEEdouble();
+    case S_x87DoubleExtended:
+      return x87DoubleExtended();
     case S_IEEEquad:
       return IEEEquad();
     case S_PPCDoubleDouble:
       return PPCDoubleDouble();
-    case S_Float8E5M2:
-      return Float8E5M2();
-    case S_x87DoubleExtended:
-      return x87DoubleExtended();
     }
     llvm_unreachable("Unrecognised floating semantics");
   }
@@ -154,14 +151,12 @@ namespace llvm {
       return S_IEEEsingle;
     else if (&Sem == &llvm::APFloat::IEEEdouble())
       return S_IEEEdouble;
+    else if (&Sem == &llvm::APFloat::x87DoubleExtended())
+      return S_x87DoubleExtended;
     else if (&Sem == &llvm::APFloat::IEEEquad())
       return S_IEEEquad;
     else if (&Sem == &llvm::APFloat::PPCDoubleDouble())
       return S_PPCDoubleDouble;
-    else if (&Sem == &llvm::APFloat::Float8E5M2())
-      return S_Float8E5M2;
-    else if (&Sem == &llvm::APFloat::x87DoubleExtended())
-      return S_x87DoubleExtended;
     else
       llvm_unreachable("Unknown floating semantics");
   }
@@ -178,15 +173,18 @@ namespace llvm {
   const fltSemantics &APFloatBase::IEEEdouble() {
     return semIEEEdouble;
   }
-  const fltSemantics &APFloatBase::IEEEquad() { return semIEEEquad; }
-  const fltSemantics &APFloatBase::PPCDoubleDouble() {
-    return semPPCDoubleDouble;
+  const fltSemantics &APFloatBase::IEEEquad() {
+    return semIEEEquad;
   }
-  const fltSemantics &APFloatBase::Float8E5M2() { return semFloat8E5M2; }
   const fltSemantics &APFloatBase::x87DoubleExtended() {
     return semX87DoubleExtended;
   }
-  const fltSemantics &APFloatBase::Bogus() { return semBogus; }
+  const fltSemantics &APFloatBase::Bogus() {
+    return semBogus;
+  }
+  const fltSemantics &APFloatBase::PPCDoubleDouble() {
+    return semPPCDoubleDouble;
+  }
 
   constexpr RoundingMode APFloatBase::rmNearestTiesToEven;
   constexpr RoundingMode APFloatBase::rmTowardPositive;
@@ -3355,33 +3353,6 @@ APInt IEEEFloat::convertHalfAPFloatToAPInt() const {
                     (mysignificand & 0x3ff)));
 }
 
-APInt IEEEFloat::convertFloat8E5M2APFloatToAPInt() const {
-  assert(semantics == (const llvm::fltSemantics *)&semFloat8E5M2);
-  assert(partCount() == 1);
-
-  uint32_t myexponent, mysignificand;
-
-  if (isFiniteNonZero()) {
-    myexponent = exponent + 15; // bias
-    mysignificand = (uint32_t)*significandParts();
-    if (myexponent == 1 && !(mysignificand & 0x4))
-      myexponent = 0; // denormal
-  } else if (category == fcZero) {
-    myexponent = 0;
-    mysignificand = 0;
-  } else if (category == fcInfinity) {
-    myexponent = 0x1f;
-    mysignificand = 0;
-  } else {
-    assert(category == fcNaN && "Unknown category!");
-    myexponent = 0x1f;
-    mysignificand = (uint32_t)*significandParts();
-  }
-
-  return APInt(8, (((sign & 1) << 7) | ((myexponent & 0x1f) << 2) |
-                   (mysignificand & 0x3)));
-}
-
 // This function creates an APInt that is just a bit map of the floating
 // point constant as it would appear in memory.  It is not a conversion,
 // and treating the result as a normal integer is unlikely to be useful.
@@ -3405,9 +3376,6 @@ APInt IEEEFloat::bitcastToAPInt() const {
   if (semantics == (const llvm::fltSemantics *)&semPPCDoubleDoubleLegacy)
     return convertPPCDoubleDoubleAPFloatToAPInt();
 
-  if (semantics == (const llvm::fltSemantics *)&semFloat8E5M2)
-    return convertFloat8E5M2APFloatToAPInt();
-
   assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended &&
          "unknown format!");
   return convertF80LongDoubleAPFloatToAPInt();
@@ -3635,34 +3603,6 @@ void IEEEFloat::initFromHalfAPInt(const APInt &api) {
   }
 }
 
-void IEEEFloat::initFromFloat8E5M2APInt(const APInt &api) {
-  uint32_t i = (uint32_t)*api.getRawData();
-  uint32_t myexponent = (i >> 2) & 0x1f;
-  uint32_t mysignificand = i & 0x3;
-
-  initialize(&semFloat8E5M2);
-  assert(partCount() == 1);
-
-  sign = i >> 7;
-  if (myexponent == 0 && mysignificand == 0) {
-    makeZero(sign);
-  } else if (myexponent == 0x1f && mysignificand == 0) {
-    makeInf(sign);
-  } else if (myexponent == 0x1f && mysignificand != 0) {
-    category = fcNaN;
-    exponent = exponentNaN();
-    *significandParts() = mysignificand;
-  } else {
-    category = fcNormal;
-    exponent = myexponent - 15; // bias
-    *significandParts() = mysignificand;
-    if (myexponent == 0) // denormal
-      exponent = -14;
-    else
-      *significandParts() |= 0x4; // integer bit
-  }
-}
-
 /// Treat api as containing the bits of a floating point number.  Currently
 /// we infer the floating point type from the size of the APInt.  The
 /// isIEEE argument distinguishes between PPC128 and IEEE128 (not meaningful
@@ -3683,8 +3623,6 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
     return initFromQuadrupleAPInt(api);
   if (Sem == &semPPCDoubleDoubleLegacy)
     return initFromPPCDoubleDoubleAPInt(api);
-  if (Sem == &semFloat8E5M2)
-    return initFromFloat8E5M2APInt(api);
 
   llvm_unreachable(nullptr);
 }

diff  --git a/llvm/unittests/ADT/APFloatTest.cpp b/llvm/unittests/ADT/APFloatTest.cpp
index 90e39bf07b930..3caa09f364ff4 100644
--- a/llvm/unittests/ADT/APFloatTest.cpp
+++ b/llvm/unittests/ADT/APFloatTest.cpp
@@ -1752,20 +1752,18 @@ TEST(APFloatTest, getZero) {
     const unsigned long long bitPattern[2];
     const unsigned bitPatternLength;
   } const GetZeroTest[] = {
-      {&APFloat::IEEEhalf(), false, {0, 0}, 1},
-      {&APFloat::IEEEhalf(), true, {0x8000ULL, 0}, 1},
-      {&APFloat::IEEEsingle(), false, {0, 0}, 1},
-      {&APFloat::IEEEsingle(), true, {0x80000000ULL, 0}, 1},
-      {&APFloat::IEEEdouble(), false, {0, 0}, 1},
-      {&APFloat::IEEEdouble(), true, {0x8000000000000000ULL, 0}, 1},
-      {&APFloat::IEEEquad(), false, {0, 0}, 2},
-      {&APFloat::IEEEquad(), true, {0, 0x8000000000000000ULL}, 2},
-      {&APFloat::PPCDoubleDouble(), false, {0, 0}, 2},
-      {&APFloat::PPCDoubleDouble(), true, {0x8000000000000000ULL, 0}, 2},
-      {&APFloat::x87DoubleExtended(), false, {0, 0}, 2},
-      {&APFloat::x87DoubleExtended(), true, {0, 0x8000ULL}, 2},
-      {&APFloat::Float8E5M2(), false, {0, 0}, 1},
-      {&APFloat::Float8E5M2(), true, {0x80ULL, 0}, 1},
+    { &APFloat::IEEEhalf(), false, {0, 0}, 1},
+    { &APFloat::IEEEhalf(), true, {0x8000ULL, 0}, 1},
+    { &APFloat::IEEEsingle(), false, {0, 0}, 1},
+    { &APFloat::IEEEsingle(), true, {0x80000000ULL, 0}, 1},
+    { &APFloat::IEEEdouble(), false, {0, 0}, 1},
+    { &APFloat::IEEEdouble(), true, {0x8000000000000000ULL, 0}, 1},
+    { &APFloat::IEEEquad(), false, {0, 0}, 2},
+    { &APFloat::IEEEquad(), true, {0, 0x8000000000000000ULL}, 2},
+    { &APFloat::PPCDoubleDouble(), false, {0, 0}, 2},
+    { &APFloat::PPCDoubleDouble(), true, {0x8000000000000000ULL, 0}, 2},
+    { &APFloat::x87DoubleExtended(), false, {0, 0}, 2},
+    { &APFloat::x87DoubleExtended(), true, {0, 0x8000ULL}, 2},
   };
   const unsigned NumGetZeroTests = 12;
   for (unsigned i = 0; i < NumGetZeroTests; ++i) {
@@ -4756,7 +4754,7 @@ TEST(APFloatTest, x87Next) {
   EXPECT_TRUE(ilogb(F) == -1);
 }
 
-TEST(APFloatTest, IEEEdoubleToDouble) {
+TEST(APFloatTest, ToDouble) {
   APFloat DPosZero(0.0);
   APFloat DPosZeroToDouble(DPosZero.convertToDouble());
   EXPECT_TRUE(DPosZeroToDouble.isPosZero());
@@ -4792,9 +4790,7 @@ TEST(APFloatTest, IEEEdoubleToDouble) {
             DNegInf.convertToDouble());
   APFloat DQNaN = APFloat::getQNaN(APFloat::IEEEdouble());
   EXPECT_TRUE(std::isnan(DQNaN.convertToDouble()));
-}
 
-TEST(APFloatTest, IEEEsingleToDouble) {
   APFloat FPosZero(0.0F);
   APFloat FPosZeroToDouble(FPosZero.convertToDouble());
   EXPECT_TRUE(FPosZeroToDouble.isPosZero());
@@ -4829,9 +4825,7 @@ TEST(APFloatTest, IEEEsingleToDouble) {
             FNegInf.convertToDouble());
   APFloat FQNaN = APFloat::getQNaN(APFloat::IEEEsingle());
   EXPECT_TRUE(std::isnan(FQNaN.convertToDouble()));
-}
 
-TEST(APFloatTest, IEEEhalfToDouble) {
   APFloat HPosZero = APFloat::getZero(APFloat::IEEEhalf());
   APFloat HPosZeroToDouble(HPosZero.convertToDouble());
   EXPECT_TRUE(HPosZeroToDouble.isPosZero());
@@ -4873,9 +4867,7 @@ TEST(APFloatTest, IEEEhalfToDouble) {
   APFloat BNegZero = APFloat::getZero(APFloat::IEEEhalf(), true);
   APFloat BNegZeroToDouble(BNegZero.convertToDouble());
   EXPECT_TRUE(BNegZeroToDouble.isNegZero());
-}
 
-TEST(APFloatTest, BFloatToDouble) {
   APFloat BOne(APFloat::BFloat(), "1.0");
   EXPECT_EQ(1.0, BOne.convertToDouble());
   APFloat BPosLargest = APFloat::getLargest(APFloat::BFloat(), false);
@@ -4909,35 +4901,7 @@ TEST(APFloatTest, BFloatToDouble) {
   EXPECT_TRUE(std::isnan(BQNaN.convertToDouble()));
 }
 
-TEST(APFloatTest, Float8E5M2ToDouble) {
-  APFloat One(APFloat::Float8E5M2(), "1.0");
-  EXPECT_EQ(1.0, One.convertToDouble());
-  APFloat Two(APFloat::Float8E5M2(), "2.0");
-  EXPECT_EQ(2.0, Two.convertToDouble());
-  APFloat PosLargest = APFloat::getLargest(APFloat::Float8E5M2(), false);
-  EXPECT_EQ(5.734400e+04, PosLargest.convertToDouble());
-  APFloat NegLargest = APFloat::getLargest(APFloat::Float8E5M2(), true);
-  EXPECT_EQ(-5.734400e+04, NegLargest.convertToDouble());
-  APFloat PosSmallest =
-      APFloat::getSmallestNormalized(APFloat::Float8E5M2(), false);
-  EXPECT_EQ(0x1.p-14, PosSmallest.convertToDouble());
-  APFloat NegSmallest =
-      APFloat::getSmallestNormalized(APFloat::Float8E5M2(), true);
-  EXPECT_EQ(-0x1.p-14, NegSmallest.convertToDouble());
-
-  APFloat SmallestDenorm = APFloat::getSmallest(APFloat::Float8E5M2(), false);
-  EXPECT_TRUE(SmallestDenorm.isDenormal());
-  EXPECT_EQ(0x1p-16, SmallestDenorm.convertToDouble());
-
-  APFloat PosInf = APFloat::getInf(APFloat::Float8E5M2());
-  EXPECT_EQ(std::numeric_limits<double>::infinity(), PosInf.convertToDouble());
-  APFloat NegInf = APFloat::getInf(APFloat::Float8E5M2(), true);
-  EXPECT_EQ(-std::numeric_limits<double>::infinity(), NegInf.convertToDouble());
-  APFloat QNaN = APFloat::getQNaN(APFloat::Float8E5M2());
-  EXPECT_TRUE(std::isnan(QNaN.convertToDouble()));
-}
-
-TEST(APFloatTest, IEEEsingleToFloat) {
+TEST(APFloatTest, ToFloat) {
   APFloat FPosZero(0.0F);
   APFloat FPosZeroToFloat(FPosZero.convertToFloat());
   EXPECT_TRUE(FPosZeroToFloat.isPosZero());
@@ -4971,9 +4935,7 @@ TEST(APFloatTest, IEEEsingleToFloat) {
   EXPECT_EQ(-std::numeric_limits<float>::infinity(), FNegInf.convertToFloat());
   APFloat FQNaN = APFloat::getQNaN(APFloat::IEEEsingle());
   EXPECT_TRUE(std::isnan(FQNaN.convertToFloat()));
-}
 
-TEST(APFloatTest, IEEEhalfToFloat) {
   APFloat HPosZero = APFloat::getZero(APFloat::IEEEhalf());
   APFloat HPosZeroToFloat(HPosZero.convertToFloat());
   EXPECT_TRUE(HPosZeroToFloat.isPosZero());
@@ -5007,9 +4969,7 @@ TEST(APFloatTest, IEEEhalfToFloat) {
   EXPECT_EQ(-std::numeric_limits<float>::infinity(), HNegInf.convertToFloat());
   APFloat HQNaN = APFloat::getQNaN(APFloat::IEEEhalf());
   EXPECT_TRUE(std::isnan(HQNaN.convertToFloat()));
-}
 
-TEST(APFloatTest, BFloatToFloat) {
   APFloat BPosZero = APFloat::getZero(APFloat::BFloat());
   APFloat BPosZeroToDouble(BPosZero.convertToFloat());
   EXPECT_TRUE(BPosZeroToDouble.isPosZero());
@@ -5048,41 +5008,4 @@ TEST(APFloatTest, BFloatToFloat) {
   APFloat BQNaN = APFloat::getQNaN(APFloat::BFloat());
   EXPECT_TRUE(std::isnan(BQNaN.convertToFloat()));
 }
-
-TEST(APFloatTest, Float8E5M2ToFloat) {
-  APFloat PosZero = APFloat::getZero(APFloat::Float8E5M2());
-  APFloat PosZeroToFloat(PosZero.convertToFloat());
-  EXPECT_TRUE(PosZeroToFloat.isPosZero());
-  APFloat NegZero = APFloat::getZero(APFloat::Float8E5M2(), true);
-  APFloat NegZeroToFloat(NegZero.convertToFloat());
-  EXPECT_TRUE(NegZeroToFloat.isNegZero());
-
-  APFloat One(APFloat::Float8E5M2(), "1.0");
-  EXPECT_EQ(1.0F, One.convertToFloat());
-  APFloat Two(APFloat::Float8E5M2(), "2.0");
-  EXPECT_EQ(2.0F, Two.convertToFloat());
-
-  APFloat PosLargest = APFloat::getLargest(APFloat::Float8E5M2(), false);
-  EXPECT_EQ(5.734400e+04, PosLargest.convertToFloat());
-  APFloat NegLargest = APFloat::getLargest(APFloat::Float8E5M2(), true);
-  EXPECT_EQ(-5.734400e+04, NegLargest.convertToFloat());
-  APFloat PosSmallest =
-      APFloat::getSmallestNormalized(APFloat::Float8E5M2(), false);
-  EXPECT_EQ(0x1.p-14, PosSmallest.convertToFloat());
-  APFloat NegSmallest =
-      APFloat::getSmallestNormalized(APFloat::Float8E5M2(), true);
-  EXPECT_EQ(-0x1.p-14, NegSmallest.convertToFloat());
-
-  APFloat SmallestDenorm = APFloat::getSmallest(APFloat::Float8E5M2(), false);
-  EXPECT_TRUE(SmallestDenorm.isDenormal());
-  EXPECT_EQ(0x1.p-16, SmallestDenorm.convertToFloat());
-
-  APFloat PosInf = APFloat::getInf(APFloat::Float8E5M2());
-  EXPECT_EQ(std::numeric_limits<float>::infinity(), PosInf.convertToFloat());
-  APFloat NegInf = APFloat::getInf(APFloat::Float8E5M2(), true);
-  EXPECT_EQ(-std::numeric_limits<float>::infinity(), NegInf.convertToFloat());
-  APFloat QNaN = APFloat::getQNaN(APFloat::Float8E5M2());
-  EXPECT_TRUE(std::isnan(QNaN.convertToFloat()));
 }
-
-} // namespace

diff  --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 9bd3d510b2483..d1083f9323bf0 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -67,13 +67,6 @@ MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx);
 // Floating-point types.
 //===----------------------------------------------------------------------===//
 
-/// Checks whether the given type is an f8E5M2 type.
-MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type);
-
-/// Creates an f8E5M2 type in the given context. The type is owned by the
-/// context.
-MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx);
-
 /// Checks whether the given type is a bf16 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
 

diff  --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 870c834ce2b0d..a10503c72735a 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -59,7 +59,6 @@ class Builder {
                        Attribute metadata = Attribute());
 
   // Types.
-  FloatType getFloat8E5M2Type();
   FloatType getBF16Type();
   FloatType getF16Type();
   FloatType getF32Type();

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 1925127251558..cb282516438c9 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -46,7 +46,6 @@ class FloatType : public Type {
   static FloatType getF64(MLIRContext *ctx);
   static FloatType getF80(MLIRContext *ctx);
   static FloatType getF128(MLIRContext *ctx);
-  static FloatType getFloat8E5M2(MLIRContext *ctx);
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(Type type);
@@ -374,12 +373,8 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
 }
 
 inline bool FloatType::classof(Type type) {
-  return type.isa<Float8E5M2Type, BFloat16Type, Float16Type, Float32Type,
-                  Float64Type, Float80Type, Float128Type>();
-}
-
-inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
-  return Float8E5M2Type::get(ctx);
+  return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
+                  Float80Type, Float128Type>();
 }
 
 inline FloatType FloatType::getBF16(MLIRContext *ctx) {

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 50d8b3a0cb44a..aaadc1f6b9a75 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -76,28 +76,6 @@ class Builtin_FloatType<string name>
   }];
 }
 
-//===----------------------------------------------------------------------===//
-// Float8E5M2Type
-
-def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2"> {
-  let summary = "8-bit floating point with 2 bit mantissa";
-  let description = [{
-    An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
-    mantissa. This is not a standard type as defined by IEEE-754, but it
-    follows similar conventions with the following characteristics:
-
-      * bit encoding: S1E5M2
-      * exponent bias: 15
-      * infinities: supported with exponent set to all 1s and mantissa 0s
-      * NaNs: supported with exponent bits set to all 1s and mantissa of 
-        (01, 10, or 11)
-      * denormals when exponent is 0
-
-    Described in: https://arxiv.org/abs/2209.05433
-  }];
-}
-
-
 //===----------------------------------------------------------------------===//
 // BFloat16Type
 

diff  --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 18660cb9e6f5e..7f657d8f81e5e 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -123,7 +123,6 @@ class Type {
   // Convenience predicates.  This is only for floating point types,
   // derived types should use isa/dyn_cast.
   bool isIndex() const;
-  bool isFloat8E5M2() const;
   bool isBF16() const;
   bool isF16() const;
   bool isF32() const;

diff  --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 02eba88f78b0d..94d2fd3687fc3 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -93,7 +93,6 @@ TOK_KEYWORD(f16)
 TOK_KEYWORD(f32)
 TOK_KEYWORD(f64)
 TOK_KEYWORD(f80)
-TOK_KEYWORD(f8E5M2)
 TOK_KEYWORD(f128)
 TOK_KEYWORD(false)
 TOK_KEYWORD(floordiv)

diff  --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 5ab7a89eac01e..16da006809d29 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -30,7 +30,6 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
   case Token::kw_tuple:
   case Token::kw_vector:
   case Token::inttype:
-  case Token::kw_f8E5M2:
   case Token::kw_bf16:
   case Token::kw_f16:
   case Token::kw_f32:
@@ -287,9 +286,6 @@ Type Parser::parseNonFunctionType() {
   }
 
   // float-type
-  case Token::kw_f8E5M2:
-    consumeToken(Token::kw_f8E5M2);
-    return builder.getFloat8E5M2Type();
   case Token::kw_bf16:
     consumeToken(Token::kw_bf16);
     return builder.getBF16Type();

diff  --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index ad9a5bc6640e2..be44b76e8c615 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -68,14 +68,6 @@ MlirType mlirIndexTypeGet(MlirContext ctx) {
 // Floating-point types.
 //===----------------------------------------------------------------------===//
 
-bool mlirTypeIsAFloat8E5M2(MlirType type) {
-  return unwrap(type).isFloat8E5M2();
-}
-
-MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
-  return wrap(FloatType::getFloat8E5M2(unwrap(ctx)));
-}
-
 bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
 
 MlirType mlirBF16TypeGet(MlirContext ctx) {

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index f51ea60de523f..aaebef4341889 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2179,7 +2179,6 @@ void AsmPrinter::Impl::printType(Type type) {
                            opaqueTy.getTypeData());
       })
       .Case<IndexType>([&](Type) { os << "index"; })
-      .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
       .Case<Float16Type>([&](Type) { os << "f16"; })
       .Case<Float32Type>([&](Type) { os << "f32"; })

diff  --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 053ffce1b1579..8bd3c72d185f1 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -33,10 +33,6 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
 // Types.
 //===----------------------------------------------------------------------===//
 
-FloatType Builder::getFloat8E5M2Type() {
-  return FloatType::getFloat8E5M2(context);
-}
-
 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
 
 FloatType Builder::getF16Type() { return FloatType::getF16(context); }

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 013686719e8e8..b1c2e3e92eea2 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -88,8 +88,6 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
 //===----------------------------------------------------------------------===//
 
 unsigned FloatType::getWidth() {
-  if (isa<Float8E5M2Type>())
-    return 8;
   if (isa<Float16Type, BFloat16Type>())
     return 16;
   if (isa<Float32Type>())
@@ -105,8 +103,6 @@ unsigned FloatType::getWidth() {
 
 /// Returns the floating semantics for the given type.
 const llvm::fltSemantics &FloatType::getFloatSemantics() {
-  if (isa<Float8E5M2Type>())
-    return APFloat::Float8E5M2();
   if (isa<BFloat16Type>())
     return APFloat::BFloat();
   if (isa<Float16Type>())

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 7ddcc2ff11d92..3d41823eb6c16 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -206,7 +206,6 @@ class MLIRContextImpl {
   StorageUniquer typeUniquer;
 
   /// Cached Type Instances.
-  Float8E5M2Type f8E5M2Ty;
   BFloat16Type bf16Ty;
   Float16Type f16Ty;
   Float32Type f32Ty;
@@ -277,7 +276,6 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
 
   //// Types.
   /// Floating-point Types.
-  impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
   impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
   impl->f16Ty = TypeUniquer::get<Float16Type>(this);
   impl->f32Ty = TypeUniquer::get<Float32Type>(this);
@@ -842,9 +840,6 @@ AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) {
 /// This should not be used directly.
 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
 
-Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) {
-  return context->getImpl().f8E5M2Ty;
-}
 BFloat16Type BFloat16Type::get(MLIRContext *context) {
   return context->getImpl().bf16Ty;
 }

diff  --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index b97388bf33f52..defe2dacfac29 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -18,7 +18,6 @@ using namespace mlir::detail;
 
 MLIRContext *Type::getContext() const { return getDialect().getContext(); }
 
-bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); }
 bool Type::isBF16() const { return isa<BFloat16Type>(); }
 bool Type::isF16() const { return isa<Float16Type>(); }
 bool Type::isF32() const { return isa<Float32Type>(); }

diff  --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 1bdbdc25bdcac..f27e53e334978 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -31,42 +31,6 @@ func.func @any_attr_of_fail() {
 
 // -----
 
-//===----------------------------------------------------------------------===//
-// Test float attributes
-//===----------------------------------------------------------------------===//
-
-func.func @float_attrs_pass() {
-  "test.float_attrs"() {
-    // CHECK: float_attr = 2.000000e+00 : f8E5M2
-    float_attr = 2. : f8E5M2
-  } : () -> ()
-  "test.float_attrs"() {
-    // CHECK: float_attr = 2.000000e+00 : f16
-    float_attr = 2. : f16
-  } : () -> ()
-  "test.float_attrs"() {
-    // CHECK: float_attr = 2.000000e+00 : bf16
-    float_attr = 2. : bf16
-  } : () -> ()
-  "test.float_attrs"() {
-    // CHECK: float_attr = 2.000000e+00 : f32
-    float_attr = 2. : f32
-  } : () -> ()
-  "test.float_attrs"() {
-    // CHECK: float_attr = 2.000000e+00 : f64
-    float_attr = 2. : f64
-  } : () -> ()
-  "test.float_attrs"() {
-    // CHECK: float_attr = 2.000000e+00 : f80
-    float_attr = 2. : f80
-  } : () -> ()
-  "test.float_attrs"() {
-    // CHECK: float_attr = 2.000000e+00 : f128
-    float_attr = 2. : f128
-  } : () -> ()
-  return
-}
-
 //===----------------------------------------------------------------------===//
 // Test integer attributes
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 3d6acb89ca68d..0a9c6beb39786 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -193,14 +193,6 @@ def TypeStringAttrWithTypeOp : TEST_Op<"string_attr_with_type"> {
   let assemblyFormat = "$attr attr-dict";
 }
 
-def FloatAttrOp : TEST_Op<"float_attrs"> {
-  // TODO: Clean up the OpBase float type and attribute selectors so they
-  // can express all of the types.
-  let arguments = (ins
-    AnyAttr:$float_attr
-  );
-}
-
 def I32Case5:  I32EnumAttrCase<"case5", 5>;
 def I32Case10: I32EnumAttrCase<"case10", 10>;
 


        


More information about the llvm-commits mailing list