[llvm] [IR][Float8] Add two kinds float8 IR type (PR #89900)

via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 24 02:37:16 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-ir

Author: None (JinjinLi868)

<details>
<summary>Changes</summary>

Add two 8-bit floating-point IR types for ML: float8e5m2 and
float8e4m3fn. float8e5m2 has a 5-bit exponent and a 2-bit mantissa and
behaves like an IEEE 754 floating-point type. float8e4m3fn has a 4-bit
exponent and a 3-bit mantissa; it has no infinities, and NaN is encoded
with all exponent and mantissa bits set to 1.
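
To make the two layouts concrete, here is a minimal standalone sketch that decodes a raw byte of either format into a `double`. It does not use LLVM. The helper names (`decodeFP8`, `decodeFloat8E5M2`, `decodeFloat8E4M3FN`) are hypothetical, and the exponent biases (15 for e5m2, 7 for e4m3fn) are assumptions taken from the common FP8 conventions, not from the text of this patch.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical helper: decode an 8-bit float with the given field widths.
// HasInf selects IEEE-like handling of the all-ones exponent (inf/NaN);
// otherwise only the all-ones exponent+mantissa pattern is NaN ("fn" style).
static double decodeFP8(uint8_t Bits, int ExpBits, int ManBits, int Bias,
                        bool HasInf) {
  assert(ExpBits + ManBits == 7 && "sign + exponent + mantissa must be 8 bits");
  const int Sign = Bits >> 7;
  const int ExpMax = (1 << ExpBits) - 1;
  const int ManMax = (1 << ManBits) - 1;
  const int Exp = (Bits >> ManBits) & ExpMax;
  const int Man = Bits & ManMax;

  double Mag;
  if (HasInf && Exp == ExpMax) {
    // IEEE-like: zero mantissa is infinity, anything else is NaN.
    Mag = Man == 0 ? std::numeric_limits<double>::infinity()
                   : std::numeric_limits<double>::quiet_NaN();
  } else if (!HasInf && Exp == ExpMax && Man == ManMax) {
    // Finite-only format: the single all-ones pattern is NaN.
    Mag = std::numeric_limits<double>::quiet_NaN();
  } else if (Exp == 0) {
    // Zero or subnormal: no implicit leading 1.
    Mag = std::ldexp(static_cast<double>(Man), 1 - Bias - ManBits);
  } else {
    // Normal: implicit leading 1 in front of the mantissa bits.
    Mag = std::ldexp(static_cast<double>((1 << ManBits) + Man),
                     Exp - Bias - ManBits);
  }
  return Sign ? -Mag : Mag;
}

// Assumed bias 15: 1 sign, 5 exponent, 2 mantissa bits, with inf and NaN.
double decodeFloat8E5M2(uint8_t Bits) {
  return decodeFP8(Bits, 5, 2, 15, /*HasInf=*/true);
}

// Assumed bias 7: 1 sign, 4 exponent, 3 mantissa bits, no infinities.
double decodeFloat8E4M3FN(uint8_t Bits) {
  return decodeFP8(Bits, 4, 3, 7, /*HasInf=*/false);
}
```

Under these assumptions, `0x7C` decodes to infinity as float8e5m2, while the same all-ones exponent in float8e4m3fn (`0x78`) is an ordinary finite value. That is the practical difference between the two types.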


---

Patch is 46.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89900.diff


23 Files Affected:

- (modified) llvm/docs/BitCodeFormat.rst (+16) 
- (modified) llvm/docs/LangRef.rst (+22-11) 
- (modified) llvm/include/llvm-c/Core.h (+15) 
- (modified) llvm/include/llvm/Bitcode/LLVMBitCodes.h (+3) 
- (modified) llvm/include/llvm/IR/Constants.h (+2) 
- (modified) llvm/include/llvm/IR/DataLayout.h (+3) 
- (modified) llvm/include/llvm/IR/IRBuilder.h (+10) 
- (modified) llvm/include/llvm/IR/Type.h (+17) 
- (modified) llvm/lib/AsmParser/LLLexer.cpp (+36-21) 
- (modified) llvm/lib/AsmParser/LLParser.cpp (+10-3) 
- (modified) llvm/lib/Bitcode/Reader/BitcodeReader.cpp (+25-1) 
- (modified) llvm/lib/Bitcode/Writer/BitcodeWriter.cpp (+17-15) 
- (modified) llvm/lib/CodeGen/MIRParser/MILexer.cpp (+1-1) 
- (modified) llvm/lib/IR/AsmWriter.cpp (+25-14) 
- (modified) llvm/lib/IR/Constants.cpp (+70-10) 
- (modified) llvm/lib/IR/Core.cpp (+18-3) 
- (modified) llvm/lib/IR/DataLayout.cpp (+3) 
- (modified) llvm/lib/IR/Function.cpp (+13-11) 
- (modified) llvm/lib/IR/LLVMContextImpl.cpp (+1) 
- (modified) llvm/lib/IR/LLVMContextImpl.h (+2-2) 
- (modified) llvm/lib/IR/Type.cpp (+30-14) 
- (added) llvm/test/Assembler/float8.ll (+71) 
- (modified) llvm/tools/llvm-c-test/echo.cpp (+5-1) 


``````````diff
diff --git a/llvm/docs/BitCodeFormat.rst b/llvm/docs/BitCodeFormat.rst
index 46af2e421a258c..bd9b1f87422585 100644
--- a/llvm/docs/BitCodeFormat.rst
+++ b/llvm/docs/BitCodeFormat.rst
@@ -1139,6 +1139,22 @@ TYPE_CODE_VOID Record
 
 The ``VOID`` record (code 2) adds a ``void`` type to the type table.
 
+TYPE_CODE_Float8E5M2 Record
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+``[Float8E5M2]``
+
+The ``Float8E5M2`` record (code 27) adds a ``float8e5m2`` (8-bit floating point)
+type to the type table.
+
+TYPE_CODE_Float8E4M3FN Record
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+``[Float8E4M3FN]``
+
+The ``Float8E4M3FN`` record (code 28) adds a ``float8e4m3fn`` (8-bit floating
+point) type to the type table.
+
 TYPE_CODE_HALF Record
 ^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 9592929d79feb4..3106dc0cc25d5e 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -3847,6 +3847,14 @@ Floating-Point Types
    * - Type
      - Description
 
+   * - ``float8e5m2``
+     - 8-bit floating-point value (2-bit significand)
+
+   * - ``float8e4m3fn``
+     - 8-bit floating-point value (3-bit significand); there are no infinity
+       values, and NaN is represented with the exponent and mantissa bits set
+       to all 1s
+
    * - ``half``
      - 16-bit floating-point value
 
@@ -3871,9 +3879,9 @@ Floating-Point Types
    * - ``ppc_fp128``
      - 128-bit floating-point value (two 64-bits)
 
-The binary format of half, float, double, and fp128 correspond to the
-IEEE-754-2008 specifications for binary16, binary32, binary64, and binary128
-respectively.
+The binary format of float8e5m2, half, float, double, and fp128 correspond
+to the IEEE-754-2008 specifications for binary8, binary16, binary32, binary64,
+and binary128 respectively.
 
 X86_amx Type
 """"""""""""
@@ -4329,20 +4337,23 @@ number of digits. For example, NaN's, infinities, and other special
 values are represented in their IEEE hexadecimal format so that assembly
 and disassembly do not cause any bits to change in the constants.
 
-When using the hexadecimal form, constants of types bfloat, half, float, and
-double are represented using the 16-digit form shown above (which matches the
-IEEE754 representation for double); bfloat, half and float values must, however,
-be exactly representable as bfloat, IEEE 754 half, and IEEE 754 single
+When using the hexadecimal form, constants of types float8e5m2, float8e4m3fn,
+bfloat, half, float, and double are represented using the 16-digit form shown
+above (which matches the IEEE754 representation for double); float8e5m2,
+float8e4m3fn, bfloat, half and float values must, however, be exactly representable
+as float8e5m2, float8e4m3fn, bfloat, IEEE 754 half, and IEEE 754 single
 precision respectively. Hexadecimal format is always used for long double, and
 there are three forms of long double. The 80-bit format used by x86 is
 represented as ``0xK`` followed by 20 hexadecimal digits. The 128-bit format
 used by PowerPC (two adjacent doubles) is represented by ``0xM`` followed by 32
 hexadecimal digits. The IEEE 128-bit format is represented by ``0xL`` followed
 by 32 hexadecimal digits. Long doubles will only work if they match the long
-double format on your target.  The IEEE 16-bit format (half precision) is
-represented by ``0xH`` followed by 4 hexadecimal digits. The bfloat 16-bit
-format is represented by ``0xR`` followed by 4 hexadecimal digits. All
-hexadecimal formats are big-endian (sign bit at the left).
+double format on your target. The IEEE 8-bit format (float8e5m2 precision) is
+represented by ``0xS`` followed by 2 hexadecimal digits. The float8e4m3fn 8-bit
+format is represented by ``0xQ`` followed by 2 hexadecimal digits. The IEEE 16-bit
+format (half precision) is represented by ``0xH`` followed by 4 hexadecimal digits.
+The bfloat 16-bit format is represented by ``0xR`` followed by 4 hexadecimal digits.
+All hexadecimal formats are big-endian (sign bit at the left).
 
 There are no constants of type x86_mmx and x86_amx.
 
diff --git a/llvm/include/llvm-c/Core.h b/llvm/include/llvm-c/Core.h
index 0b03f3b36fcdd3..7cc958beccb62d 100644
--- a/llvm/include/llvm-c/Core.h
+++ b/llvm/include/llvm-c/Core.h
@@ -167,6 +167,8 @@ typedef enum {
   LLVMBFloatTypeKind,    /**< 16 bit brain floating point type */
   LLVMX86_AMXTypeKind,   /**< X86 AMX */
   LLVMTargetExtTypeKind, /**< Target extension type */
+  LLVMFloat8E5M2TypeKind, /**< 8 bit floating point with 2 bit mantissa */
+  LLVMFloat8E4M3FNTypeKind, /**< 8 bit floating point with 3 bit mantissa */
 } LLVMTypeKind;
 
 typedef enum {
@@ -1298,6 +1300,17 @@ unsigned LLVMGetIntTypeWidth(LLVMTypeRef IntegerTy);
  * @{
  */
 
+
+/**
+ * Obtain an 8-bit floating point type (E5M2) from a context.
+ */
+LLVMTypeRef LLVMFloat8E5M2TypeInContext(LLVMContextRef C);
+
+/**
+ * Obtain an 8-bit floating point type (E4M3FN) from a context.
+ */
+LLVMTypeRef LLVMFloat8E4M3FNTypeInContext(LLVMContextRef C);
+
 /**
  * Obtain a 16-bit floating point type from a context.
  */
@@ -1339,6 +1352,8 @@ LLVMTypeRef LLVMPPCFP128TypeInContext(LLVMContextRef C);
  *
  * These map to the functions in this group of the same name.
  */
+LLVMTypeRef LLVMFloat8E5M2Type(void);
+LLVMTypeRef LLVMFloat8E4M3FNType(void);
 LLVMTypeRef LLVMHalfType(void);
 LLVMTypeRef LLVMBFloatType(void);
 LLVMTypeRef LLVMFloatType(void);
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 909eb833c601a9..ce6d639c2455c4 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -177,6 +177,9 @@ enum TypeCodes {
   TYPE_CODE_OPAQUE_POINTER = 25, // OPAQUE_POINTER: [addrspace]
 
   TYPE_CODE_TARGET_TYPE = 26, // TARGET_TYPE
+
+  TYPE_CODE_Float8E5M2 = 27, // Float8E5M2
+  TYPE_CODE_Float8E4M3FN = 28, // Float8E4M3FN
 };
 
 enum OperandBundleTagCode {
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 4290ef4486c6f4..3c82b74a111aa5 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -733,6 +733,7 @@ class ConstantDataArray final : public ConstantDataSequential {
   /// number of bits of the type contained in the passed in ArrayRef.
   /// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
   /// that this can return a ConstantAggregateZero object.
+  static Constant *getFP(Type *ElementType, ArrayRef<uint8_t> Elts);
   static Constant *getFP(Type *ElementType, ArrayRef<uint16_t> Elts);
   static Constant *getFP(Type *ElementType, ArrayRef<uint32_t> Elts);
   static Constant *getFP(Type *ElementType, ArrayRef<uint64_t> Elts);
@@ -805,6 +806,7 @@ class ConstantDataVector final : public ConstantDataSequential {
   /// number of bits of the type contained in the passed in ArrayRef.
   /// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
   /// that this can return a ConstantAggregateZero object.
+  static Constant *getFP(Type *ElementType, ArrayRef<uint8_t> Elts);
   static Constant *getFP(Type *ElementType, ArrayRef<uint16_t> Elts);
   static Constant *getFP(Type *ElementType, ArrayRef<uint32_t> Elts);
   static Constant *getFP(Type *ElementType, ArrayRef<uint64_t> Elts);
diff --git a/llvm/include/llvm/IR/DataLayout.h b/llvm/include/llvm/IR/DataLayout.h
index d14adfe1590be5..2f0c55d8e758c6 100644
--- a/llvm/include/llvm/IR/DataLayout.h
+++ b/llvm/include/llvm/IR/DataLayout.h
@@ -687,6 +687,9 @@ inline TypeSize DataLayout::getTypeSizeInBits(Type *Ty) const {
     return getStructLayout(cast<StructType>(Ty))->getSizeInBits();
   case Type::IntegerTyID:
     return TypeSize::getFixed(Ty->getIntegerBitWidth());
+  case Type::Float8E5M2TyID:
+  case Type::Float8E4M3FNTyID:
+    return TypeSize::getFixed(8);
   case Type::HalfTyID:
   case Type::BFloatTyID:
     return TypeSize::getFixed(16);
diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
index b6534a1962a2f5..de981586b4cbe7 100644
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -540,6 +540,16 @@ class IRBuilderBase {
     return Type::getIntNTy(Context, N);
   }
 
+  /// Fetch the type representing an 8-bit e5m2 floating point value.
+  Type *getFloat8E5M2Ty() {
+    return Type::getFloat8E5M2Ty(Context);
+  }
+
+  /// Fetch the type representing an 8-bit e4m3fn floating point value.
+  Type *getFloat8E4M3FNTy() {
+    return Type::getFloat8E4M3FNTy(Context);
+  }
+
   /// Fetch the type representing a 16-bit floating point value.
   Type *getHalfTy() {
     return Type::getHalfTy(Context);
diff --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h
index 1f0133c08e7d60..bf9f63d2cdda4c 100644
--- a/llvm/include/llvm/IR/Type.h
+++ b/llvm/include/llvm/IR/Type.h
@@ -55,6 +55,8 @@ class Type {
     // PrimitiveTypes
     HalfTyID = 0,  ///< 16-bit floating point type
     BFloatTyID,    ///< 16-bit floating point type (7-bit significand)
+    Float8E5M2TyID,   ///< 8-bit floating type (5 Bit exponent)
+    Float8E4M3FNTyID, ///< 8-bit floating type (4 Bit exponent)
     FloatTyID,     ///< 32-bit floating point type
     DoubleTyID,    ///< 64-bit floating point type
     X86_FP80TyID,  ///< 80-bit floating point type (X87)
@@ -139,6 +141,17 @@ class Type {
   /// Return true if this is 'void'.
   bool isVoidTy() const { return getTypeID() == VoidTyID; }
 
+  /// Return true if this is 'F8E5M2'.
+  bool isFloat8E5M2Ty() const { return getTypeID() == Float8E5M2TyID; }
+
+  /// Return true if this is 'F8E4M3FN'.
+  bool isFloat8E4M3FNTy() const { return getTypeID() == Float8E4M3FNTyID; }
+
+  /// Return true if this is an 8-bit float type.
+  bool is8BitFPTy() const {
+    return getTypeID() == Float8E5M2TyID || getTypeID() == Float8E4M3FNTyID;
+  }
+
   /// Return true if this is 'half', a 16-bit IEEE fp type.
   bool isHalfTy() const { return getTypeID() == HalfTyID; }
 
@@ -174,6 +187,8 @@ class Type {
     case FloatTyID:
     case HalfTyID:
     case BFloatTyID:
+    case Float8E5M2TyID:
+    case Float8E4M3FNTyID:
     case FP128TyID:
       return true;
     default:
@@ -445,6 +460,8 @@ class Type {
   //
   static Type *getVoidTy(LLVMContext &C);
   static Type *getLabelTy(LLVMContext &C);
+  static Type *getFloat8E5M2Ty(LLVMContext &C);
+  static Type *getFloat8E4M3FNTy(LLVMContext &C);
   static Type *getHalfTy(LLVMContext &C);
   static Type *getBFloatTy(LLVMContext &C);
   static Type *getFloatTy(LLVMContext &C);
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index 8ded07ffd8bd25..aabd6262304f16 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -825,20 +825,22 @@ lltok::Kind LLLexer::LexIdentifier() {
     }                                                                          \
   } while (false)
 
-  TYPEKEYWORD("void",      Type::getVoidTy(Context));
-  TYPEKEYWORD("half",      Type::getHalfTy(Context));
-  TYPEKEYWORD("bfloat",    Type::getBFloatTy(Context));
-  TYPEKEYWORD("float",     Type::getFloatTy(Context));
-  TYPEKEYWORD("double",    Type::getDoubleTy(Context));
-  TYPEKEYWORD("x86_fp80",  Type::getX86_FP80Ty(Context));
-  TYPEKEYWORD("fp128",     Type::getFP128Ty(Context));
-  TYPEKEYWORD("ppc_fp128", Type::getPPC_FP128Ty(Context));
-  TYPEKEYWORD("label",     Type::getLabelTy(Context));
-  TYPEKEYWORD("metadata",  Type::getMetadataTy(Context));
-  TYPEKEYWORD("x86_mmx",   Type::getX86_MMXTy(Context));
-  TYPEKEYWORD("x86_amx",   Type::getX86_AMXTy(Context));
-  TYPEKEYWORD("token",     Type::getTokenTy(Context));
-  TYPEKEYWORD("ptr",       PointerType::getUnqual(Context));
+  TYPEKEYWORD("void",          Type::getVoidTy(Context));
+  TYPEKEYWORD("float8e5m2",    Type::getFloat8E5M2Ty(Context));
+  TYPEKEYWORD("float8e4m3fn",  Type::getFloat8E4M3FNTy(Context));
+  TYPEKEYWORD("half",          Type::getHalfTy(Context));
+  TYPEKEYWORD("bfloat",        Type::getBFloatTy(Context));
+  TYPEKEYWORD("float",         Type::getFloatTy(Context));
+  TYPEKEYWORD("double",        Type::getDoubleTy(Context));
+  TYPEKEYWORD("x86_fp80",      Type::getX86_FP80Ty(Context));
+  TYPEKEYWORD("fp128",         Type::getFP128Ty(Context));
+  TYPEKEYWORD("ppc_fp128",     Type::getPPC_FP128Ty(Context));
+  TYPEKEYWORD("label",         Type::getLabelTy(Context));
+  TYPEKEYWORD("metadata",      Type::getMetadataTy(Context));
+  TYPEKEYWORD("x86_mmx",       Type::getX86_MMXTy(Context));
+  TYPEKEYWORD("x86_amx",       Type::getX86_AMXTy(Context));
+  TYPEKEYWORD("token",         Type::getTokenTy(Context));
+  TYPEKEYWORD("ptr",           PointerType::getUnqual(Context));
 
 #undef TYPEKEYWORD
 
@@ -1006,18 +1008,21 @@ lltok::Kind LLLexer::LexIdentifier() {
 
 /// Lex all tokens that start with a 0x prefix, knowing they match and are not
 /// labels.
-///    HexFPConstant     0x[0-9A-Fa-f]+
-///    HexFP80Constant   0xK[0-9A-Fa-f]+
-///    HexFP128Constant  0xL[0-9A-Fa-f]+
-///    HexPPC128Constant 0xM[0-9A-Fa-f]+
-///    HexHalfConstant   0xH[0-9A-Fa-f]+
-///    HexBFloatConstant 0xR[0-9A-Fa-f]+
+///    HexFPConstant         0x[0-9A-Fa-f]+
+///    HexFP80Constant       0xK[0-9A-Fa-f]+
+///    HexFP128Constant      0xL[0-9A-Fa-f]+
+///    HexPPC128Constant     0xM[0-9A-Fa-f]+
+///    HexHalfConstant       0xH[0-9A-Fa-f]+
+///    HexBFloatConstant     0xR[0-9A-Fa-f]+
+///    HexFP8E4M3FNConstant  0xQ[0-9A-Fa-f]+
+///    HexFP8E5M2Constant    0xS[0-9A-Fa-f]+
+
 lltok::Kind LLLexer::Lex0x() {
   CurPtr = TokStart + 2;
 
   char Kind;
   if ((CurPtr[0] >= 'K' && CurPtr[0] <= 'M') || CurPtr[0] == 'H' ||
-      CurPtr[0] == 'R') {
+      CurPtr[0] == 'R' || CurPtr[0] == 'Q' || CurPtr[0] == 'S') {
     Kind = *CurPtr++;
   } else {
     Kind = 'J';
@@ -1068,6 +1073,16 @@ lltok::Kind LLLexer::Lex0x() {
     APFloatVal = APFloat(APFloat::BFloat(),
                          APInt(16, HexIntToVal(TokStart + 3, CurPtr)));
     return lltok::APFloat;
+  case 'Q':
+    // FP8E4M3FN
+    APFloatVal = APFloat(APFloat::Float8E4M3FN(),
+                         APInt(8, HexIntToVal(TokStart + 3, CurPtr)));
+    return lltok::APFloat;
+  case 'S':
+    // FP8E5M2
+    APFloatVal = APFloat(APFloat::Float8E5M2(),
+                         APInt(8, HexIntToVal(TokStart + 3, CurPtr)));
+    return lltok::APFloat;
   }
 }
 
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 63104129f8c2df..d32e154c8baf66 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -5998,13 +5998,20 @@ bool LLParser::convertValIDToValue(Type *Ty, ValID &ID, Value *&V,
         !ConstantFP::isValueValidForType(Ty, ID.APFloatVal))
       return error(ID.Loc, "floating point constant invalid for type");
 
-    // The lexer has no type info, so builds all half, bfloat, float, and double
-    // FP constants as double.  Fix this here.  Long double does not need this.
+    // The lexer has no type info, so builds all float8e5m2, float8e4m3fn, half,
+    // bfloat, float, and double FP constants as double.  Fix this here. Long
+    // double does not need this.
     if (&ID.APFloatVal.getSemantics() == &APFloat::IEEEdouble()) {
       // Check for signaling before potentially converting and losing that info.
       bool IsSNAN = ID.APFloatVal.isSignaling();
       bool Ignored;
-      if (Ty->isHalfTy())
+      if (Ty->isFloat8E5M2Ty())
+        ID.APFloatVal.convert(APFloat::Float8E5M2(), APFloat::rmNearestTiesToEven,
+                              &Ignored);
+      else if (Ty->isFloat8E4M3FNTy())
+        ID.APFloatVal.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven,
+                              &Ignored);
+      else if (Ty->isHalfTy())
         ID.APFloatVal.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven,
                               &Ignored);
       else if (Ty->isBFloatTy())
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 0b7fcd88418894..3e3ec1b2664089 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -2404,6 +2404,12 @@ Error BitcodeReader::parseTypeTableBody() {
     case bitc::TYPE_CODE_VOID:      // VOID
       ResultTy = Type::getVoidTy(Context);
       break;
+    case bitc::TYPE_CODE_Float8E4M3FN: // FP8E4M3FN
+      ResultTy = Type::getFloat8E4M3FNTy(Context);
+      break;
+    case bitc::TYPE_CODE_Float8E5M2:   // FP8E5M2
+      ResultTy = Type::getFloat8E5M2Ty(Context);
+      break;
     case bitc::TYPE_CODE_HALF:     // HALF
       ResultTy = Type::getHalfTy(Context);
       break;
@@ -3138,7 +3144,13 @@ Error BitcodeReader::parseConstants() {
         return error("Invalid float const record");
 
       auto *ScalarTy = CurTy->getScalarType();
-      if (ScalarTy->isHalfTy())
+      if (ScalarTy->isFloat8E4M3FNTy())
+        V = ConstantFP::get(Context, APFloat(APFloat::Float8E4M3FN(),
+                                             APInt(8, (uint8_t)Record[0])));
+      else if (ScalarTy->isFloat8E5M2Ty())
+        V = ConstantFP::get(Context, APFloat(APFloat::Float8E5M2(),
+                                             APInt(8, (uint8_t)Record[0])));
+      else if (ScalarTy->isHalfTy())
         V = ConstantFP::get(CurTy, APFloat(APFloat::IEEEhalf(),
                                            APInt(16, (uint16_t)Record[0])));
       else if (ScalarTy->isBFloatTy())
@@ -3234,6 +3246,18 @@ Error BitcodeReader::parseConstants() {
           V = ConstantDataVector::get(Context, Elts);
         else
           V = ConstantDataArray::get(Context, Elts);
+      } else if (EltTy->isFloat8E4M3FNTy()) {
+        SmallVector<uint8_t, 16> Elts(Record.begin(), Record.end());
+        if (isa<VectorType>(CurTy))
+          V = ConstantDataVector::getFP(EltTy, Elts);
+        else
+          V = ConstantDataArray::getFP(EltTy, Elts);
+      } else if (EltTy->isFloat8E5M2Ty()) {
+        SmallVector<uint8_t, 16> Elts(Record.begin(), Record.end());
+        if (isa<VectorType>(CurTy))
+          V = ConstantDataVector::getFP(EltTy, Elts);
+        else
+          V = ConstantDataArray::getFP(EltTy, Elts);
       } else if (EltTy->isHalfTy()) {
         SmallVector<uint16_t, 16> Elts(Record.begin(), Record.end());
         if (isa<VectorType>(CurTy))
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 6d01e3b4d82189..46b5d0ed9440ee 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -1043,19 +1043,21 @@ void ModuleBitcodeWriter::writeTypeTable() {
     unsigned Code = 0;
 
     switch (T->getTypeID()) {
-    case Type::VoidTyID:      Code = bitc::TYPE_CODE_VOID;      break;
-    case Type::HalfTyID:      Code = bitc::TYPE_CODE_HALF;      break;
-    case Type::BFloatTyID:    Code = bitc::TYPE_CODE_BFLOAT;    break;
-    case Type::FloatTyID:     Code = bitc::TYPE_CODE_FLOAT;     break;
-    case Type::DoubleTyID:    Code = bitc::TYPE_CODE_DOUBLE;    break;
-    case Type::X86_FP80TyID:  Code = bitc::TYPE_CODE_X86_FP80;  break;
-    case Type::FP128TyID:     Code = bitc::TYPE_CODE_FP128;     break;
-    case Type::PPC_FP128TyID: Code = bitc::TYPE_CODE_PPC_FP128; break;
-    case Type::LabelTyID:     Code = bitc::TYPE_CODE_LABEL;     break;
-    case Type::MetadataTyID:  Code = bitc::TYPE_CODE_METADATA;  break;
-    case Type::X86_MMXTyID:   Code = bitc::TYPE_CODE_X86_MMX;   break;
-    case Type::X86_AMXTyID:   Code = bitc::TYPE_CODE_X86_AMX;   break;
-    case Type::TokenTyID:     Code = bitc::TYPE_CODE_TOKEN;     break;
+    case Type::VoidTyID:          Code = bitc::TYPE_CODE_VOID;          break;
+    case Type::Float8E4M3FNTyID:  Code = bitc::TYPE_CODE_Float8E4M3FN;  break;
+    case Type::Float8E5M2TyID:    Code = bitc::TYPE_CODE_Float8E5M2;    break;
+    case Type::HalfTyID:          Code = bitc::TYPE_CODE_HALF;          break;
+    case Type::BFloatTyID:        Code = bitc::TYPE_CODE_BFLOAT;        break;
+    case Type::FloatTyID:         Code = bitc::TYPE_CODE_FLOAT;         break;
+    case Type::DoubleTyID:        Code = bitc::TYPE_CODE_DOUBLE;        break;
+    case Type::X86_FP80TyID:      Code = bitc::TYPE_CODE_X86_FP80;      break;
+    case Type::FP128TyID:         Code = bitc::TYPE_CODE_FP128;         break;
+    case Type::PPC_FP128TyID:     Code = bitc::TYPE_CODE_PPC_FP128;     break;
+    case Type::LabelTyID:         Code = bitc::TYPE_CODE_LABEL;         break;
+    case Type::MetadataTyID:      Code = bitc::TYPE_CODE_M...
[truncated]

``````````

</details>
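
The LangRef text in this patch describes the new hexadecimal constant forms as ``0xS`` (float8e5m2) and ``0xQ`` (float8e4m3fn), each followed by 2 hexadecimal digits. As a rough illustration of that syntax only, the standalone sketch below parses such a token into its kind and raw byte. It is not LLVM's lexer: `FP8Kind`, `FP8Literal`, `hexDigit`, and `parseFP8HexLiteral` are hypothetical names, and the code assumes exactly two digits, as the LangRef wording states. It needs C++17.

```cpp
#include <cstdint>
#include <optional>
#include <string_view>

// Hypothetical types for illustration; not part of LLVM.
enum class FP8Kind { E5M2, E4M3FN };

struct FP8Literal {
  FP8Kind Kind;
  uint8_t Bits; // raw encoding, sign bit at the left
};

// Return the value of one hex digit, or -1 if C is not a hex digit.
static int hexDigit(char C) {
  if (C >= '0' && C <= '9') return C - '0';
  if (C >= 'a' && C <= 'f') return C - 'a' + 10;
  if (C >= 'A' && C <= 'F') return C - 'A' + 10;
  return -1;
}

// Parse "0xS<hh>" or "0xQ<hh>"; assumes exactly two hex digits.
std::optional<FP8Literal> parseFP8HexLiteral(std::string_view Tok) {
  if (Tok.size() != 5 || Tok[0] != '0' || Tok[1] != 'x')
    return std::nullopt;

  FP8Kind Kind;
  if (Tok[2] == 'S')
    Kind = FP8Kind::E5M2;
  else if (Tok[2] == 'Q')
    Kind = FP8Kind::E4M3FN;
  else
    return std::nullopt;

  // The digits start at offset 3: after "0x" and the kind letter.
  const int Hi = hexDigit(Tok[3]);
  const int Lo = hexDigit(Tok[4]);
  if (Hi < 0 || Lo < 0)
    return std::nullopt;

  return FP8Literal{Kind, static_cast<uint8_t>(Hi * 16 + Lo)};
}
```

For example, `parseFP8HexLiteral("0xS3C")` yields kind `E5M2` with bits `0x3C`, and a token with any other kind letter is rejected.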


https://github.com/llvm/llvm-project/pull/89900

