[llvm] [IR][Float8] Add two kinds float8 IR type (PR #89900)

via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 25 22:53:50 PDT 2024


https://github.com/JinjinLi868 updated https://github.com/llvm/llvm-project/pull/89900

>From 6866553d1d390b02fa4d776a5691ce1b14e12864 Mon Sep 17 00:00:00 2001
From: Jinjin Li <lijinjin.868 at bytedance.com>
Date: Tue, 23 Apr 2024 18:07:11 +0800
Subject: [PATCH] [IR][Float8] Add two kinds float8 IR type

Support two 8-bit floating-point IR types (float8e5m2 and float8e4m3fn)
for ML. Float8e5m2 has a 5-bit exponent and a 2-bit mantissa and behaves
like an IEEE 754 floating-point type. Float8e4m3fn has a 4-bit exponent
and a 3-bit mantissa, with no infinities.
---
 llvm/docs/BitCodeFormat.rst               |  16 ++++
 llvm/docs/LangRef.rst                     |  33 ++++---
 llvm/include/llvm-c/Core.h                |  15 ++++
 llvm/include/llvm/Bitcode/LLVMBitCodes.h  |   3 +
 llvm/include/llvm/IR/Constants.h          |   2 +
 llvm/include/llvm/IR/DataLayout.h         |   3 +
 llvm/include/llvm/IR/IRBuilder.h          |  10 +++
 llvm/include/llvm/IR/Type.h               |  17 ++++
 llvm/lib/AsmParser/LLLexer.cpp            |  57 +++++++-----
 llvm/lib/AsmParser/LLParser.cpp           |  13 ++-
 llvm/lib/Bitcode/Reader/BitcodeReader.cpp |  26 +++++-
 llvm/lib/Bitcode/Writer/BitcodeWriter.cpp |  32 +++----
 llvm/lib/CodeGen/MIRParser/MILexer.cpp    |   2 +-
 llvm/lib/IR/AsmWriter.cpp                 |  39 ++++++---
 llvm/lib/IR/Constants.cpp                 |  80 ++++++++++++++---
 llvm/lib/IR/Core.cpp                      |  21 ++++-
 llvm/lib/IR/DataLayout.cpp                |   3 +
 llvm/lib/IR/Function.cpp                  |  24 ++---
 llvm/lib/IR/LLVMContextImpl.cpp           |   1 +
 llvm/lib/IR/LLVMContextImpl.h             |   4 +-
 llvm/lib/IR/Type.cpp                      |  44 +++++++---
 llvm/test/Assembler/float8.ll             | 101 ++++++++++++++++++++++
 llvm/tools/llvm-c-test/echo.cpp           |   6 +-
 23 files changed, 445 insertions(+), 107 deletions(-)
 create mode 100644 llvm/test/Assembler/float8.ll

diff --git a/llvm/docs/BitCodeFormat.rst b/llvm/docs/BitCodeFormat.rst
index 46af2e421a258c..bd9b1f87422585 100644
--- a/llvm/docs/BitCodeFormat.rst
+++ b/llvm/docs/BitCodeFormat.rst
@@ -1139,6 +1139,22 @@ TYPE_CODE_VOID Record
 
 The ``VOID`` record (code 2) adds a ``void`` type to the type table.
 
+TYPE_CODE_Float8E5M2 Record
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+``[Float8E5M2]``
+
+The ``Float8E5M2`` record (code 27) adds a ``float8e5m2`` (8-bit floating point)
+type to the type table.
+
+TYPE_CODE_Float8E4M3FN Record
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+``[Float8E4M3FN]``
+
+The ``Float8E4M3FN`` record (code 28) adds a ``float8e4m3fn`` (8-bit floating
+point) type to the type table.
+
 TYPE_CODE_HALF Record
 ^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 9592929d79feb4..bb8b3466531f77 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -3847,6 +3847,14 @@ Floating-Point Types
    * - Type
      - Description
 
+   * - ``float8e5m2``
+     - 8-bit floating-point value (2-bit significand)
+
+   * - ``float8e4m3fn``
+     - 8-bit floating-point value (3-bit significand). There are no infinity
+       values, and NaN is represented by setting all exponent and mantissa
+       bits to 1.
+
    * - ``half``
      - 16-bit floating-point value
 
@@ -3871,9 +3879,9 @@ Floating-Point Types
    * - ``ppc_fp128``
      - 128-bit floating-point value (two 64-bits)
 
-The binary format of half, float, double, and fp128 correspond to the
-IEEE-754-2008 specifications for binary16, binary32, binary64, and binary128
-respectively.
+The binary format of half, float, double, and fp128 correspond to the
+IEEE-754-2008 specifications for binary16, binary32, binary64, and binary128
+respectively; float8e5m2 follows IEEE-754 conventions but is not an IEEE-754-2008 format.
 
 X86_amx Type
 """"""""""""
@@ -4329,20 +4337,23 @@ number of digits. For example, NaN's, infinities, and other special
 values are represented in their IEEE hexadecimal format so that assembly
 and disassembly do not cause any bits to change in the constants.
 
-When using the hexadecimal form, constants of types bfloat, half, float, and
-double are represented using the 16-digit form shown above (which matches the
-IEEE754 representation for double); bfloat, half and float values must, however,
-be exactly representable as bfloat, IEEE 754 half, and IEEE 754 single
+When using the hexadecimal form, constants of types float8e5m2, float8e4m3fn,
+bfloat, half, float, and double are represented using the 16-digit form shown
+above (which matches the IEEE754 representation for double); float8e5m2,
+float8e4m3fn, bfloat, half and float values must, however, be exactly representable
+as float8e5m2, float8e4m3fn, bfloat, IEEE 754 half, and IEEE 754 single
 precision respectively. Hexadecimal format is always used for long double, and
 there are three forms of long double. The 80-bit format used by x86 is
 represented as ``0xK`` followed by 20 hexadecimal digits. The 128-bit format
 used by PowerPC (two adjacent doubles) is represented by ``0xM`` followed by 32
 hexadecimal digits. The IEEE 128-bit format is represented by ``0xL`` followed
 by 32 hexadecimal digits. Long doubles will only work if they match the long
-double format on your target.  The IEEE 16-bit format (half precision) is
-represented by ``0xH`` followed by 4 hexadecimal digits. The bfloat 16-bit
-format is represented by ``0xR`` followed by 4 hexadecimal digits. All
-hexadecimal formats are big-endian (sign bit at the left).
+double format on your target. The float8e5m2 8-bit format is represented by
+``0xS`` followed by 2 hexadecimal digits. The float8e4m3fn 8-bit format is
+represented by ``0xQ`` followed by 2 hexadecimal digits. The IEEE 16-bit
+format (half precision) is represented by ``0xH`` followed by 4 hexadecimal
+digits. The bfloat 16-bit format is represented by ``0xR`` followed by 4
+hexadecimal digits. All hexadecimal formats are big-endian (sign bit at the left).
 
 There are no constants of type x86_mmx and x86_amx.
 
diff --git a/llvm/include/llvm-c/Core.h b/llvm/include/llvm-c/Core.h
index 0b03f3b36fcdd3..7cc958beccb62d 100644
--- a/llvm/include/llvm-c/Core.h
+++ b/llvm/include/llvm-c/Core.h
@@ -167,6 +167,8 @@ typedef enum {
   LLVMBFloatTypeKind,    /**< 16 bit brain floating point type */
   LLVMX86_AMXTypeKind,   /**< X86 AMX */
   LLVMTargetExtTypeKind, /**< Target extension type */
+  LLVMFloat8E5M2TypeKind, /**< 8 bit floating point with 2 bit mantissa */
+  LLVMFloat8E4M3FNTypeKind, /**< 8 bit floating point with 3 bit mantissa */
 } LLVMTypeKind;
 
 typedef enum {
@@ -1298,6 +1300,17 @@ unsigned LLVMGetIntTypeWidth(LLVMTypeRef IntegerTy);
  * @{
  */
 
+
+/**
+ * Obtain the 8-bit float8e5m2 floating point type from a context.
+ */
+LLVMTypeRef LLVMFloat8E5M2TypeInContext(LLVMContextRef C);
+
+/**
+ * Obtain the 8-bit float8e4m3fn floating point type from a context.
+ */
+LLVMTypeRef LLVMFloat8E4M3FNTypeInContext(LLVMContextRef C);
+
 /**
  * Obtain a 16-bit floating point type from a context.
  */
@@ -1339,6 +1352,8 @@ LLVMTypeRef LLVMPPCFP128TypeInContext(LLVMContextRef C);
  *
  * These map to the functions in this group of the same name.
  */
+LLVMTypeRef LLVMFloat8E5M2Type(void);
+LLVMTypeRef LLVMFloat8E4M3FNType(void);
 LLVMTypeRef LLVMHalfType(void);
 LLVMTypeRef LLVMBFloatType(void);
 LLVMTypeRef LLVMFloatType(void);
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 909eb833c601a9..ce6d639c2455c4 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -177,6 +177,9 @@ enum TypeCodes {
   TYPE_CODE_OPAQUE_POINTER = 25, // OPAQUE_POINTER: [addrspace]
 
   TYPE_CODE_TARGET_TYPE = 26, // TARGET_TYPE
+
+  TYPE_CODE_Float8E5M2 = 27, // Float8E5M2
+  TYPE_CODE_Float8E4M3FN = 28, // Float8E4M3FN
 };
 
 enum OperandBundleTagCode {
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 4290ef4486c6f4..3c82b74a111aa5 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -733,6 +733,7 @@ class ConstantDataArray final : public ConstantDataSequential {
   /// number of bits of the type contained in the passed in ArrayRef.
   /// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
   /// that this can return a ConstantAggregateZero object.
+  static Constant *getFP(Type *ElementType, ArrayRef<uint8_t> Elts);
   static Constant *getFP(Type *ElementType, ArrayRef<uint16_t> Elts);
   static Constant *getFP(Type *ElementType, ArrayRef<uint32_t> Elts);
   static Constant *getFP(Type *ElementType, ArrayRef<uint64_t> Elts);
@@ -805,6 +806,7 @@ class ConstantDataVector final : public ConstantDataSequential {
   /// number of bits of the type contained in the passed in ArrayRef.
   /// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
   /// that this can return a ConstantAggregateZero object.
+  static Constant *getFP(Type *ElementType, ArrayRef<uint8_t> Elts);
   static Constant *getFP(Type *ElementType, ArrayRef<uint16_t> Elts);
   static Constant *getFP(Type *ElementType, ArrayRef<uint32_t> Elts);
   static Constant *getFP(Type *ElementType, ArrayRef<uint64_t> Elts);
diff --git a/llvm/include/llvm/IR/DataLayout.h b/llvm/include/llvm/IR/DataLayout.h
index d14adfe1590be5..2f0c55d8e758c6 100644
--- a/llvm/include/llvm/IR/DataLayout.h
+++ b/llvm/include/llvm/IR/DataLayout.h
@@ -687,6 +687,9 @@ inline TypeSize DataLayout::getTypeSizeInBits(Type *Ty) const {
     return getStructLayout(cast<StructType>(Ty))->getSizeInBits();
   case Type::IntegerTyID:
     return TypeSize::getFixed(Ty->getIntegerBitWidth());
+  case Type::Float8E5M2TyID:
+  case Type::Float8E4M3FNTyID:
+    return TypeSize::getFixed(8);
   case Type::HalfTyID:
   case Type::BFloatTyID:
     return TypeSize::getFixed(16);
diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
index b6534a1962a2f5..de981586b4cbe7 100644
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -540,6 +540,16 @@ class IRBuilderBase {
     return Type::getIntNTy(Context, N);
   }
 
+  /// Fetch the type representing an 8-bit float8e5m2 floating point value.
+  Type *getFloat8E5M2Ty() {
+    return Type::getFloat8E5M2Ty(Context);
+  }
+
+  /// Fetch the type representing an 8-bit float8e4m3fn floating point value.
+  Type *getFloat8E4M3FNTy() {
+    return Type::getFloat8E4M3FNTy(Context);
+  }
+
   /// Fetch the type representing a 16-bit floating point value.
   Type *getHalfTy() {
     return Type::getHalfTy(Context);
diff --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h
index 1f0133c08e7d60..bf9f63d2cdda4c 100644
--- a/llvm/include/llvm/IR/Type.h
+++ b/llvm/include/llvm/IR/Type.h
@@ -55,6 +55,8 @@ class Type {
     // PrimitiveTypes
     HalfTyID = 0,  ///< 16-bit floating point type
     BFloatTyID,    ///< 16-bit floating point type (7-bit significand)
+    Float8E5M2TyID,   ///< 8-bit floating type (5 Bit exponent)
+    Float8E4M3FNTyID, ///< 8-bit floating type (4 Bit exponent)
     FloatTyID,     ///< 32-bit floating point type
     DoubleTyID,    ///< 64-bit floating point type
     X86_FP80TyID,  ///< 80-bit floating point type (X87)
@@ -139,6 +141,17 @@ class Type {
   /// Return true if this is 'void'.
   bool isVoidTy() const { return getTypeID() == VoidTyID; }
 
+  /// Return true if this is 'float8e5m2', an 8-bit FP type (5-bit exponent).
+  bool isFloat8E5M2Ty() const { return getTypeID() == Float8E5M2TyID; }
+
+  /// Return true if this is 'float8e4m3fn', an 8-bit FP type (4-bit exponent).
+  bool isFloat8E4M3FNTy() const { return getTypeID() == Float8E4M3FNTyID; }
+
+  /// Return true if this is an 8-bit float type (float8e5m2 or float8e4m3fn).
+  bool is8BitFPTy() const {
+    return getTypeID() == Float8E5M2TyID || getTypeID() == Float8E4M3FNTyID;
+  }
+
   /// Return true if this is 'half', a 16-bit IEEE fp type.
   bool isHalfTy() const { return getTypeID() == HalfTyID; }
 
@@ -174,6 +187,8 @@ class Type {
     case FloatTyID:
     case HalfTyID:
     case BFloatTyID:
+    case Float8E5M2TyID:
+    case Float8E4M3FNTyID:
     case FP128TyID:
       return true;
     default:
@@ -445,6 +460,8 @@ class Type {
   //
   static Type *getVoidTy(LLVMContext &C);
   static Type *getLabelTy(LLVMContext &C);
+  static Type *getFloat8E5M2Ty(LLVMContext &C);
+  static Type *getFloat8E4M3FNTy(LLVMContext &C);
   static Type *getHalfTy(LLVMContext &C);
   static Type *getBFloatTy(LLVMContext &C);
   static Type *getFloatTy(LLVMContext &C);
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index 8ded07ffd8bd25..aabd6262304f16 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -825,20 +825,22 @@ lltok::Kind LLLexer::LexIdentifier() {
     }                                                                          \
   } while (false)
 
-  TYPEKEYWORD("void",      Type::getVoidTy(Context));
-  TYPEKEYWORD("half",      Type::getHalfTy(Context));
-  TYPEKEYWORD("bfloat",    Type::getBFloatTy(Context));
-  TYPEKEYWORD("float",     Type::getFloatTy(Context));
-  TYPEKEYWORD("double",    Type::getDoubleTy(Context));
-  TYPEKEYWORD("x86_fp80",  Type::getX86_FP80Ty(Context));
-  TYPEKEYWORD("fp128",     Type::getFP128Ty(Context));
-  TYPEKEYWORD("ppc_fp128", Type::getPPC_FP128Ty(Context));
-  TYPEKEYWORD("label",     Type::getLabelTy(Context));
-  TYPEKEYWORD("metadata",  Type::getMetadataTy(Context));
-  TYPEKEYWORD("x86_mmx",   Type::getX86_MMXTy(Context));
-  TYPEKEYWORD("x86_amx",   Type::getX86_AMXTy(Context));
-  TYPEKEYWORD("token",     Type::getTokenTy(Context));
-  TYPEKEYWORD("ptr",       PointerType::getUnqual(Context));
+  TYPEKEYWORD("void",          Type::getVoidTy(Context));
+  TYPEKEYWORD("float8e5m2",    Type::getFloat8E5M2Ty(Context));
+  TYPEKEYWORD("float8e4m3fn",  Type::getFloat8E4M3FNTy(Context));
+  TYPEKEYWORD("half",          Type::getHalfTy(Context));
+  TYPEKEYWORD("bfloat",        Type::getBFloatTy(Context));
+  TYPEKEYWORD("float",         Type::getFloatTy(Context));
+  TYPEKEYWORD("double",        Type::getDoubleTy(Context));
+  TYPEKEYWORD("x86_fp80",      Type::getX86_FP80Ty(Context));
+  TYPEKEYWORD("fp128",         Type::getFP128Ty(Context));
+  TYPEKEYWORD("ppc_fp128",     Type::getPPC_FP128Ty(Context));
+  TYPEKEYWORD("label",         Type::getLabelTy(Context));
+  TYPEKEYWORD("metadata",      Type::getMetadataTy(Context));
+  TYPEKEYWORD("x86_mmx",       Type::getX86_MMXTy(Context));
+  TYPEKEYWORD("x86_amx",       Type::getX86_AMXTy(Context));
+  TYPEKEYWORD("token",         Type::getTokenTy(Context));
+  TYPEKEYWORD("ptr",           PointerType::getUnqual(Context));
 
 #undef TYPEKEYWORD
 
@@ -1006,18 +1008,21 @@ lltok::Kind LLLexer::LexIdentifier() {
 
 /// Lex all tokens that start with a 0x prefix, knowing they match and are not
 /// labels.
-///    HexFPConstant     0x[0-9A-Fa-f]+
-///    HexFP80Constant   0xK[0-9A-Fa-f]+
-///    HexFP128Constant  0xL[0-9A-Fa-f]+
-///    HexPPC128Constant 0xM[0-9A-Fa-f]+
-///    HexHalfConstant   0xH[0-9A-Fa-f]+
-///    HexBFloatConstant 0xR[0-9A-Fa-f]+
+///    HexFPConstant         0x[0-9A-Fa-f]+
+///    HexFP80Constant       0xK[0-9A-Fa-f]+
+///    HexFP128Constant      0xL[0-9A-Fa-f]+
+///    HexPPC128Constant     0xM[0-9A-Fa-f]+
+///    HexHalfConstant       0xH[0-9A-Fa-f]+
+///    HexBFloatConstant     0xR[0-9A-Fa-f]+
+///    HexFP8E4M3FNConstant  0xQ[0-9A-Fa-f]+
+///    HexFP8E5M2Constant    0xS[0-9A-Fa-f]+
+///    (The Q and S forms hold the raw 8-bit encoding of the value.)
 lltok::Kind LLLexer::Lex0x() {
   CurPtr = TokStart + 2;
 
   char Kind;
   if ((CurPtr[0] >= 'K' && CurPtr[0] <= 'M') || CurPtr[0] == 'H' ||
-      CurPtr[0] == 'R') {
+      CurPtr[0] == 'R' || CurPtr[0] == 'Q' || CurPtr[0] == 'S') {
     Kind = *CurPtr++;
   } else {
     Kind = 'J';
@@ -1068,6 +1073,16 @@ lltok::Kind LLLexer::Lex0x() {
     APFloatVal = APFloat(APFloat::BFloat(),
                          APInt(16, HexIntToVal(TokStart + 3, CurPtr)));
     return lltok::APFloat;
+  case 'Q':
+    // FP8E4M3FN: hex digits start after the "0xQ" prefix.
+    APFloatVal = APFloat(APFloat::Float8E4M3FN(),
+                         APInt(8, HexIntToVal(TokStart + 3, CurPtr)));
+    return lltok::APFloat;
+  case 'S':
+    // FP8E5M2: hex digits start after the "0xS" prefix.
+    APFloatVal = APFloat(APFloat::Float8E5M2(),
+                         APInt(8, HexIntToVal(TokStart + 3, CurPtr)));
+    return lltok::APFloat;
   }
 }
 
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 63104129f8c2df..d32e154c8baf66 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -5998,13 +5998,20 @@ bool LLParser::convertValIDToValue(Type *Ty, ValID &ID, Value *&V,
         !ConstantFP::isValueValidForType(Ty, ID.APFloatVal))
       return error(ID.Loc, "floating point constant invalid for type");
 
-    // The lexer has no type info, so builds all half, bfloat, float, and double
-    // FP constants as double.  Fix this here.  Long double does not need this.
+    // The lexer has no type info, so FP constants without a type-specific hex
+    // prefix are built as double; convert them to float8e5m2, float8e4m3fn,
+    // half, bfloat, or float here. Long double does not need this.
     if (&ID.APFloatVal.getSemantics() == &APFloat::IEEEdouble()) {
       // Check for signaling before potentially converting and losing that info.
       bool IsSNAN = ID.APFloatVal.isSignaling();
       bool Ignored;
-      if (Ty->isHalfTy())
+      if (Ty->isFloat8E5M2Ty())
+        ID.APFloatVal.convert(APFloat::Float8E5M2(),
+                              APFloat::rmNearestTiesToEven, &Ignored);
+      else if (Ty->isFloat8E4M3FNTy())
+        ID.APFloatVal.convert(APFloat::Float8E4M3FN(),
+                              APFloat::rmNearestTiesToEven, &Ignored);
+      else if (Ty->isHalfTy())
         ID.APFloatVal.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven,
                               &Ignored);
       else if (Ty->isBFloatTy())
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 0b7fcd88418894..3e3ec1b2664089 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -2404,6 +2404,12 @@ Error BitcodeReader::parseTypeTableBody() {
     case bitc::TYPE_CODE_VOID:      // VOID
       ResultTy = Type::getVoidTy(Context);
       break;
+    case bitc::TYPE_CODE_Float8E4M3FN: // FP8E4M3FN
+      ResultTy = Type::getFloat8E4M3FNTy(Context);
+      break;
+    case bitc::TYPE_CODE_Float8E5M2:   // FP8E5M2
+      ResultTy = Type::getFloat8E5M2Ty(Context);
+      break;
     case bitc::TYPE_CODE_HALF:     // HALF
       ResultTy = Type::getHalfTy(Context);
       break;
@@ -3138,7 +3144,13 @@ Error BitcodeReader::parseConstants() {
         return error("Invalid float const record");
 
       auto *ScalarTy = CurTy->getScalarType();
-      if (ScalarTy->isHalfTy())
+      if (ScalarTy->isFloat8E4M3FNTy())
+        V = ConstantFP::get(CurTy, APFloat(APFloat::Float8E4M3FN(),
+                                           APInt(8, (uint8_t)Record[0])));
+      else if (ScalarTy->isFloat8E5M2Ty())
+        V = ConstantFP::get(CurTy, APFloat(APFloat::Float8E5M2(),
+                                           APInt(8, (uint8_t)Record[0])));
+      else if (ScalarTy->isHalfTy())
         V = ConstantFP::get(CurTy, APFloat(APFloat::IEEEhalf(),
                                            APInt(16, (uint16_t)Record[0])));
       else if (ScalarTy->isBFloatTy())
@@ -3234,6 +3246,18 @@ Error BitcodeReader::parseConstants() {
           V = ConstantDataVector::get(Context, Elts);
         else
           V = ConstantDataArray::get(Context, Elts);
+      } else if (EltTy->isFloat8E4M3FNTy()) {
+        SmallVector<uint8_t, 16> Elts(Record.begin(), Record.end());
+        if (isa<VectorType>(CurTy))
+          V = ConstantDataVector::getFP(EltTy, Elts);
+        else
+          V = ConstantDataArray::getFP(EltTy, Elts);
+      } else if (EltTy->isFloat8E5M2Ty()) {
+        SmallVector<uint8_t, 16> Elts(Record.begin(), Record.end());
+        if (isa<VectorType>(CurTy))
+          V = ConstantDataVector::getFP(EltTy, Elts);
+        else
+          V = ConstantDataArray::getFP(EltTy, Elts);
       } else if (EltTy->isHalfTy()) {
         SmallVector<uint16_t, 16> Elts(Record.begin(), Record.end());
         if (isa<VectorType>(CurTy))
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 6d01e3b4d82189..46b5d0ed9440ee 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -1043,19 +1043,21 @@ void ModuleBitcodeWriter::writeTypeTable() {
     unsigned Code = 0;
 
     switch (T->getTypeID()) {
-    case Type::VoidTyID:      Code = bitc::TYPE_CODE_VOID;      break;
-    case Type::HalfTyID:      Code = bitc::TYPE_CODE_HALF;      break;
-    case Type::BFloatTyID:    Code = bitc::TYPE_CODE_BFLOAT;    break;
-    case Type::FloatTyID:     Code = bitc::TYPE_CODE_FLOAT;     break;
-    case Type::DoubleTyID:    Code = bitc::TYPE_CODE_DOUBLE;    break;
-    case Type::X86_FP80TyID:  Code = bitc::TYPE_CODE_X86_FP80;  break;
-    case Type::FP128TyID:     Code = bitc::TYPE_CODE_FP128;     break;
-    case Type::PPC_FP128TyID: Code = bitc::TYPE_CODE_PPC_FP128; break;
-    case Type::LabelTyID:     Code = bitc::TYPE_CODE_LABEL;     break;
-    case Type::MetadataTyID:  Code = bitc::TYPE_CODE_METADATA;  break;
-    case Type::X86_MMXTyID:   Code = bitc::TYPE_CODE_X86_MMX;   break;
-    case Type::X86_AMXTyID:   Code = bitc::TYPE_CODE_X86_AMX;   break;
-    case Type::TokenTyID:     Code = bitc::TYPE_CODE_TOKEN;     break;
+    case Type::VoidTyID:          Code = bitc::TYPE_CODE_VOID;          break;
+    case Type::Float8E4M3FNTyID:  Code = bitc::TYPE_CODE_Float8E4M3FN;  break;
+    case Type::Float8E5M2TyID:    Code = bitc::TYPE_CODE_Float8E5M2;    break;
+    case Type::HalfTyID:          Code = bitc::TYPE_CODE_HALF;          break;
+    case Type::BFloatTyID:        Code = bitc::TYPE_CODE_BFLOAT;        break;
+    case Type::FloatTyID:         Code = bitc::TYPE_CODE_FLOAT;         break;
+    case Type::DoubleTyID:        Code = bitc::TYPE_CODE_DOUBLE;        break;
+    case Type::X86_FP80TyID:      Code = bitc::TYPE_CODE_X86_FP80;      break;
+    case Type::FP128TyID:         Code = bitc::TYPE_CODE_FP128;         break;
+    case Type::PPC_FP128TyID:     Code = bitc::TYPE_CODE_PPC_FP128;     break;
+    case Type::LabelTyID:         Code = bitc::TYPE_CODE_LABEL;         break;
+    case Type::MetadataTyID:      Code = bitc::TYPE_CODE_METADATA;      break;
+    case Type::X86_MMXTyID:       Code = bitc::TYPE_CODE_X86_MMX;       break;
+    case Type::X86_AMXTyID:       Code = bitc::TYPE_CODE_X86_AMX;       break;
+    case Type::TokenTyID:         Code = bitc::TYPE_CODE_TOKEN;         break;
     case Type::IntegerTyID:
       // INTEGER: [width]
       Code = bitc::TYPE_CODE_INTEGER;
@@ -2671,8 +2673,8 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
     } else if (const ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
       Code = bitc::CST_CODE_FLOAT;
       Type *Ty = CFP->getType()->getScalarType();
-      if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() ||
-          Ty->isDoubleTy()) {
+      if (Ty->isFloat8E4M3FNTy() || Ty->isFloat8E5M2Ty() || Ty->isHalfTy() ||
+          Ty->isBFloatTy() || Ty->isFloatTy() || Ty->isDoubleTy()) {
         Record.push_back(CFP->getValueAPF().bitcastToAPInt().getZExtValue());
       } else if (Ty->isX86_FP80Ty()) {
         // api needed to prevent premature destruction
diff --git a/llvm/lib/CodeGen/MIRParser/MILexer.cpp b/llvm/lib/CodeGen/MIRParser/MILexer.cpp
index 7bb21655320474..5ff339ddbe829f 100644
--- a/llvm/lib/CodeGen/MIRParser/MILexer.cpp
+++ b/llvm/lib/CodeGen/MIRParser/MILexer.cpp
@@ -548,7 +548,7 @@ static Cursor maybeLexMCSymbol(Cursor C, MIToken &Token,
 }
 
 static bool isValidHexFloatingPointPrefix(char C) {
-  return C == 'H' || C == 'K' || C == 'L' || C == 'M' || C == 'R';
+  return C == 'H' || (C >= 'K' && C <= 'M') || (C >= 'Q' && C <= 'S');
 }
 
 static Cursor lexFloatingPointLiteral(Cursor Range, Cursor C, MIToken &Token) {
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 941f6a7a7d8232..9c4c481dfb25af 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -561,19 +561,21 @@ void TypePrinting::incorporateTypes() {
 /// names or up references to shorten the type name where possible.
 void TypePrinting::print(Type *Ty, raw_ostream &OS) {
   switch (Ty->getTypeID()) {
-  case Type::VoidTyID:      OS << "void"; return;
-  case Type::HalfTyID:      OS << "half"; return;
-  case Type::BFloatTyID:    OS << "bfloat"; return;
-  case Type::FloatTyID:     OS << "float"; return;
-  case Type::DoubleTyID:    OS << "double"; return;
-  case Type::X86_FP80TyID:  OS << "x86_fp80"; return;
-  case Type::FP128TyID:     OS << "fp128"; return;
-  case Type::PPC_FP128TyID: OS << "ppc_fp128"; return;
-  case Type::LabelTyID:     OS << "label"; return;
-  case Type::MetadataTyID:  OS << "metadata"; return;
-  case Type::X86_MMXTyID:   OS << "x86_mmx"; return;
-  case Type::X86_AMXTyID:   OS << "x86_amx"; return;
-  case Type::TokenTyID:     OS << "token"; return;
+  case Type::VoidTyID:         OS << "void"; return;
+  case Type::Float8E4M3FNTyID: OS << "float8e4m3fn"; return;
+  case Type::Float8E5M2TyID:   OS << "float8e5m2"; return;
+  case Type::HalfTyID:         OS << "half"; return;
+  case Type::BFloatTyID:       OS << "bfloat"; return;
+  case Type::FloatTyID:        OS << "float"; return;
+  case Type::DoubleTyID:       OS << "double"; return;
+  case Type::X86_FP80TyID:     OS << "x86_fp80"; return;
+  case Type::FP128TyID:        OS << "fp128"; return;
+  case Type::PPC_FP128TyID:    OS << "ppc_fp128"; return;
+  case Type::LabelTyID:        OS << "label"; return;
+  case Type::MetadataTyID:     OS << "metadata"; return;
+  case Type::X86_MMXTyID:      OS << "x86_mmx"; return;
+  case Type::X86_AMXTyID:      OS << "x86_amx"; return;
+  case Type::TokenTyID:        OS << "token"; return;
   case Type::IntegerTyID:
     OS << 'i' << cast<IntegerType>(Ty)->getBitWidth();
     return;
@@ -1521,7 +1523,16 @@ static void WriteAPFloatInternal(raw_ostream &Out, const APFloat &APF) {
     Out << 'R';
     Out << format_hex_no_prefix(API.getZExtValue(), 4,
                                 /*Upper=*/true);
-  } else
+  } else if (&APF.getSemantics() == &APFloat::Float8E4M3FN()) {
+    Out << 'Q';
+    Out << format_hex_no_prefix(API.getZExtValue(), 2,
+                                /*Upper=*/true);
+  } else if (&APF.getSemantics() == &APFloat::Float8E5M2()) {
+    Out << 'S';
+    Out << format_hex_no_prefix(API.getZExtValue(), 2,
+                                /*Upper=*/true);
+  } else
+    // Semantics without an assigned hex prefix letter cannot be printed here.
     llvm_unreachable("Unsupported floating point type");
 }
 
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index a5fb497f54ed15..557685bd34ce64 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -371,6 +371,8 @@ Constant *Constant::getNullValue(Type *Ty) {
   switch (Ty->getTypeID()) {
   case Type::IntegerTyID:
     return ConstantInt::get(Ty, 0);
+  case Type::Float8E4M3FNTyID:
+  case Type::Float8E5M2TyID:
   case Type::HalfTyID:
   case Type::BFloatTyID:
   case Type::FloatTyID:
@@ -1255,6 +1257,8 @@ static Constant *getSequenceIfElementsMatch(Constant *C,
     else if (CI->getType()->isIntegerTy(64))
       return getIntSequenceIfElementsMatch<SequenceTy, uint64_t>(V);
   } else if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
+    if (CFP->getType()->isFloat8E4M3FNTy() || CFP->getType()->isFloat8E5M2Ty())
+      return getFPSequenceIfElementsMatch<SequenceTy, uint8_t>(V);
     if (CFP->getType()->isHalfTy() || CFP->getType()->isBFloatTy())
       return getFPSequenceIfElementsMatch<SequenceTy, uint16_t>(V);
     else if (CFP->getType()->isFloatTy())
@@ -1608,6 +1612,18 @@ bool ConstantFP::isValueValidForType(Type *Ty, const APFloat& Val) {
     return false;         // These can't be represented as floating point!
 
   // FIXME rounding mode needs to be more flexible
+  case Type::Float8E4M3FNTyID: {
+    if (&Val2.getSemantics() == &APFloat::Float8E4M3FN())
+      return true;
+    Val2.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven, &losesInfo);
+    return !losesInfo;
+  }
+  case Type::Float8E5M2TyID: {
+    if (&Val2.getSemantics() == &APFloat::Float8E5M2())
+      return true;
+    Val2.convert(APFloat::Float8E5M2(), APFloat::rmNearestTiesToEven, &losesInfo);
+    return !losesInfo;
+  }
   case Type::HalfTyID: {
     if (&Val2.getSemantics() == &APFloat::IEEEhalf())
       return true;
@@ -1627,7 +1643,9 @@ bool ConstantFP::isValueValidForType(Type *Ty, const APFloat& Val) {
     return !losesInfo;
   }
   case Type::DoubleTyID: {
-    if (&Val2.getSemantics() == &APFloat::IEEEhalf() ||
+    if (&Val2.getSemantics() == &APFloat::Float8E5M2() ||
+        &Val2.getSemantics() == &APFloat::Float8E4M3FN() ||
+        &Val2.getSemantics() == &APFloat::IEEEhalf() ||
         &Val2.getSemantics() == &APFloat::BFloat() ||
         &Val2.getSemantics() == &APFloat::IEEEsingle() ||
         &Val2.getSemantics() == &APFloat::IEEEdouble())
@@ -1636,19 +1654,25 @@ bool ConstantFP::isValueValidForType(Type *Ty, const APFloat& Val) {
     return !losesInfo;
   }
   case Type::X86_FP80TyID:
-    return &Val2.getSemantics() == &APFloat::IEEEhalf() ||
+    return &Val2.getSemantics() == &APFloat::Float8E5M2() ||
+           &Val2.getSemantics() == &APFloat::Float8E4M3FN() ||
+           &Val2.getSemantics() == &APFloat::IEEEhalf() ||
            &Val2.getSemantics() == &APFloat::BFloat() ||
            &Val2.getSemantics() == &APFloat::IEEEsingle() ||
            &Val2.getSemantics() == &APFloat::IEEEdouble() ||
            &Val2.getSemantics() == &APFloat::x87DoubleExtended();
   case Type::FP128TyID:
-    return &Val2.getSemantics() == &APFloat::IEEEhalf() ||
+    return &Val2.getSemantics() == &APFloat::Float8E5M2() ||
+           &Val2.getSemantics() == &APFloat::Float8E4M3FN() ||
+           &Val2.getSemantics() == &APFloat::IEEEhalf() ||
            &Val2.getSemantics() == &APFloat::BFloat() ||
            &Val2.getSemantics() == &APFloat::IEEEsingle() ||
            &Val2.getSemantics() == &APFloat::IEEEdouble() ||
            &Val2.getSemantics() == &APFloat::IEEEquad();
   case Type::PPC_FP128TyID:
-    return &Val2.getSemantics() == &APFloat::IEEEhalf() ||
+    return &Val2.getSemantics() == &APFloat::Float8E5M2() ||
+           &Val2.getSemantics() == &APFloat::Float8E4M3FN() ||
+           &Val2.getSemantics() == &APFloat::IEEEhalf() ||
            &Val2.getSemantics() == &APFloat::BFloat() ||
            &Val2.getSemantics() == &APFloat::IEEEsingle() ||
            &Val2.getSemantics() == &APFloat::IEEEdouble() ||
@@ -2730,7 +2754,8 @@ StringRef ConstantDataSequential::getRawDataValues() const {
 }
 
 bool ConstantDataSequential::isElementTypeCompatible(Type *Ty) {
-  if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() || Ty->isDoubleTy())
+  if (Ty->isFloat8E4M3FNTy() || Ty->isFloat8E5M2Ty() || Ty->isHalfTy() ||
+      Ty->isBFloatTy() || Ty->isFloatTy() || Ty->isDoubleTy())
     return true;
   if (auto *IT = dyn_cast<IntegerType>(Ty)) {
     switch (IT->getBitWidth()) {
@@ -2855,8 +2880,16 @@ void ConstantDataSequential::destroyConstantImpl() {
 /// element type taken from argument `ElementType', and count taken from
 /// argument `Elts'.  The amount of bits of the contained type must match the
 /// number of bits of the type contained in the passed in ArrayRef.
-/// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
-/// that this can return a ConstantAggregateZero object.
+/// (i.e. float8e4m3fn or float8e5m2 for 8bits, half or bfloat for 16bits,
+/// float for 32bits, double for 64bits) Note that this can return a
+/// ConstantAggregateZero object.
+Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint8_t> Elts) {
+  assert((ElementType->isFloat8E4M3FNTy() || ElementType->isFloat8E5M2Ty()) &&
+         "Element type is not a 8-bit float type");
+  Type *Ty = ArrayType::get(ElementType, Elts.size());
+  const char *Data = reinterpret_cast<const char *>(Elts.data());
+  return getImpl(StringRef(Data, Elts.size() * 1), Ty);
+}
 Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint16_t> Elts) {
   assert((ElementType->isHalfTy() || ElementType->isBFloatTy()) &&
          "Element type is not a 16-bit float type");
@@ -2929,8 +2962,16 @@ Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<double> Elts) {
 /// element type taken from argument `ElementType', and count taken from
 /// argument `Elts'.  The amount of bits of the contained type must match the
 /// number of bits of the type contained in the passed in ArrayRef.
-/// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
-/// that this can return a ConstantAggregateZero object.
+/// (i.e. float8e4m3 or float8e5m2 or half or bfloat for 16bits, float for 32bits,
+/// double for 64bits) Note that this can return a ConstantAggregateZero object.
+Constant *ConstantDataVector::getFP(Type *ElementType,
+                                    ArrayRef<uint8_t> Elts) {
+  assert((ElementType->isFloat8E4M3FNTy() || ElementType->isFloat8E5M2Ty()) &&
+         "Element type is not a 8-bit float type");
+  auto *Ty = FixedVectorType::get(ElementType, Elts.size());
+  const char *Data = reinterpret_cast<const char *>(Elts.data());
+  return getImpl(StringRef(Data, Elts.size() * 1), Ty);
+}
 Constant *ConstantDataVector::getFP(Type *ElementType,
                                     ArrayRef<uint16_t> Elts) {
   assert((ElementType->isHalfTy() || ElementType->isBFloatTy()) &&
@@ -2977,6 +3018,16 @@ Constant *ConstantDataVector::getSplat(unsigned NumElts, Constant *V) {
   }
 
   if (ConstantFP *CFP = dyn_cast<ConstantFP>(V)) {
+    if (CFP->getType()->isFloat8E4M3FNTy()) {
+      SmallVector<uint8_t, 16> Elts(
+          NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
+      return getFP(V->getType(), Elts);
+    }
+    if (CFP->getType()->isFloat8E5M2Ty()) {
+      SmallVector<uint8_t, 16> Elts(
+          NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
+      return getFP(V->getType(), Elts);
+    }
     if (CFP->getType()->isHalfTy()) {
       SmallVector<uint16_t, 16> Elts(
           NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
@@ -3056,6 +3107,14 @@ APFloat ConstantDataSequential::getElementAsAPFloat(unsigned Elt) const {
   switch (getElementType()->getTypeID()) {
   default:
     llvm_unreachable("Accessor can only be used when element is float/double!");
+  case Type::Float8E4M3FNTyID: {
+    auto EltVal = *reinterpret_cast<const uint8_t *>(EltPtr);
+    return APFloat(APFloat::Float8E4M3FN(), APInt(8, EltVal));
+  }
+  case Type::Float8E5M2TyID: {
+    auto EltVal = *reinterpret_cast<const uint8_t *>(EltPtr);
+    return APFloat(APFloat::Float8E5M2(), APInt(8, EltVal));
+  }
   case Type::HalfTyID: {
     auto EltVal = *reinterpret_cast<const uint16_t *>(EltPtr);
     return APFloat(APFloat::IEEEhalf(), APInt(16, EltVal));
@@ -3088,7 +3147,8 @@ double ConstantDataSequential::getElementAsDouble(unsigned Elt) const {
 }
 
 Constant *ConstantDataSequential::getElementAsConstant(unsigned Elt) const {
-  if (getElementType()->isHalfTy() || getElementType()->isBFloatTy() ||
+  if (getElementType()->isFloat8E4M3FNTy() || getElementType()->isFloat8E5M2Ty() ||
+      getElementType()->isHalfTy() || getElementType()->isBFloatTy() ||
       getElementType()->isFloatTy() || getElementType()->isDoubleTy())
     return ConstantFP::get(getContext(), getElementAsAPFloat(Elt));
 
diff --git a/llvm/lib/IR/Core.cpp b/llvm/lib/IR/Core.cpp
index 6aff94f39d9c0c..f7be166a6de665 100644
--- a/llvm/lib/IR/Core.cpp
+++ b/llvm/lib/IR/Core.cpp
@@ -559,6 +559,10 @@ LLVMTypeKind LLVMGetTypeKind(LLVMTypeRef Ty) {
   switch (unwrap(Ty)->getTypeID()) {
   case Type::VoidTyID:
     return LLVMVoidTypeKind;
+  case Type::Float8E4M3FNTyID:
+    return LLVMFloat8E4M3FNTypeKind;
+  case Type::Float8E5M2TyID:
+    return LLVMFloat8E5M2TypeKind;
   case Type::HalfTyID:
     return LLVMHalfTypeKind;
   case Type::BFloatTyID:
@@ -683,7 +687,15 @@ unsigned LLVMGetIntTypeWidth(LLVMTypeRef IntegerTy) {
 }
 
 /*--.. Operations on real types ............................................--*/
 
+LLVMTypeRef LLVMFloat8E4M3FNTypeInContext(LLVMContextRef C) {
+  return (LLVMTypeRef)Type::getFloat8E4M3FNTy(*unwrap(C));
+}
+
+LLVMTypeRef LLVMFloat8E5M2TypeInContext(LLVMContextRef C) {
+  return (LLVMTypeRef)Type::getFloat8E5M2Ty(*unwrap(C));
+}
+
 LLVMTypeRef LLVMHalfTypeInContext(LLVMContextRef C) {
   return (LLVMTypeRef) Type::getHalfTy(*unwrap(C));
 }
@@ -712,6 +724,14 @@ LLVMTypeRef LLVMX86AMXTypeInContext(LLVMContextRef C) {
   return (LLVMTypeRef) Type::getX86_AMXTy(*unwrap(C));
 }
 
+LLVMTypeRef LLVMFloat8E4M3FNType(void) {
+  return LLVMFloat8E4M3FNTypeInContext(LLVMGetGlobalContext());
+}
+
+LLVMTypeRef LLVMFloat8E5M2Type(void) {
+  return LLVMFloat8E5M2TypeInContext(LLVMGetGlobalContext());
+}
+
 LLVMTypeRef LLVMHalfType(void) {
   return LLVMHalfTypeInContext(LLVMGetGlobalContext());
 }
@@ -1520,8 +1540,8 @@ double LLVMConstRealGetDouble(LLVMValueRef ConstantVal, LLVMBool *LosesInfo) {
   ConstantFP *cFP = unwrap<ConstantFP>(ConstantVal) ;
   Type *Ty = cFP->getType();
 
-  if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() ||
-      Ty->isDoubleTy()) {
+  if (Ty->isFloat8E4M3FNTy() || Ty->isFloat8E5M2Ty() || Ty->isHalfTy() ||
+      Ty->isBFloatTy() || Ty->isFloatTy() || Ty->isDoubleTy()) {
     *LosesInfo = false;
     return cFP->getValueAPF().convertToDouble();
   }
diff --git a/llvm/lib/IR/DataLayout.cpp b/llvm/lib/IR/DataLayout.cpp
index 27411653324874..9d7a88e3399007 100644
--- a/llvm/lib/IR/DataLayout.cpp
+++ b/llvm/lib/IR/DataLayout.cpp
@@ -184,6 +184,7 @@ static const std::pair<AlignTypeEnum, LayoutAlignElem> DefaultAlignments[] = {
     {INTEGER_ALIGN, {16, Align(2), Align(2)}},   // i16
     {INTEGER_ALIGN, {32, Align(4), Align(4)}},   // i32
     {INTEGER_ALIGN, {64, Align(4), Align(8)}},   // i64
+    {FLOAT_ALIGN, {8, Align(1), Align(1)}},      // float8e4m3fn, float8e5m2
     {FLOAT_ALIGN, {16, Align(2), Align(2)}},     // half, bfloat
     {FLOAT_ALIGN, {32, Align(4), Align(4)}},     // float
     {FLOAT_ALIGN, {64, Align(8), Align(8)}},     // double
@@ -813,6 +814,8 @@ Align DataLayout::getAlignment(Type *Ty, bool abi_or_pref) const {
   }
   case Type::IntegerTyID:
     return getIntegerAlignment(Ty->getIntegerBitWidth(), abi_or_pref);
+  case Type::Float8E4M3FNTyID:
+  case Type::Float8E5M2TyID:
   case Type::HalfTyID:
   case Type::BFloatTyID:
   case Type::FloatTyID:
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index 96953ac49c19b4..82dc9dc4c09cda 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -1004,17 +1004,19 @@ static std::string getMangledTypeStr(Type *Ty, bool &HasUnnamedType) {
   } else if (Ty) {
     switch (Ty->getTypeID()) {
     default: llvm_unreachable("Unhandled type");
     case Type::VoidTyID:      Result += "isVoid";   break;
     case Type::MetadataTyID:  Result += "Metadata"; break;
+    case Type::Float8E4M3FNTyID: Result += "f8e4m3fn"; break;
+    case Type::Float8E5M2TyID: Result += "f8e5m2"; break;
     case Type::HalfTyID:      Result += "f16";      break;
     case Type::BFloatTyID:    Result += "bf16";     break;
     case Type::FloatTyID:     Result += "f32";      break;
     case Type::DoubleTyID:    Result += "f64";      break;
     case Type::X86_FP80TyID:  Result += "f80";      break;
     case Type::FP128TyID:     Result += "f128";     break;
     case Type::PPC_FP128TyID: Result += "ppcf128";  break;
     case Type::X86_MMXTyID:   Result += "x86mmx";   break;
     case Type::X86_AMXTyID:   Result += "x86amx";   break;
     case Type::IntegerTyID:
       Result += "i" + utostr(cast<IntegerType>(Ty)->getBitWidth());
       break;
diff --git a/llvm/lib/IR/LLVMContextImpl.cpp b/llvm/lib/IR/LLVMContextImpl.cpp
index 72fedd8d673980..673e0173c30514 100644
--- a/llvm/lib/IR/LLVMContextImpl.cpp
+++ b/llvm/lib/IR/LLVMContextImpl.cpp
@@ -36,6 +36,8 @@ using namespace llvm;
 LLVMContextImpl::LLVMContextImpl(LLVMContext &C)
     : DiagHandler(std::make_unique<DiagnosticHandler>()),
       VoidTy(C, Type::VoidTyID), LabelTy(C, Type::LabelTyID),
+      Float8E4M3FNTy(C, Type::Float8E4M3FNTyID),
+      Float8E5M2Ty(C, Type::Float8E5M2TyID),
       HalfTy(C, Type::HalfTyID), BFloatTy(C, Type::BFloatTyID),
       FloatTy(C, Type::FloatTyID), DoubleTy(C, Type::DoubleTyID),
       MetadataTy(C, Type::MetadataTyID), TokenTy(C, Type::TokenTyID),
diff --git a/llvm/lib/IR/LLVMContextImpl.h b/llvm/lib/IR/LLVMContextImpl.h
index 7c67e191348eaf..e839b96606ad3f 100644
--- a/llvm/lib/IR/LLVMContextImpl.h
+++ b/llvm/lib/IR/LLVMContextImpl.h
@@ -1559,8 +1559,8 @@ class LLVMContextImpl {
   ConstantInt *TheFalseVal = nullptr;
 
   // Basic type instances.
-  Type VoidTy, LabelTy, HalfTy, BFloatTy, FloatTy, DoubleTy, MetadataTy,
-      TokenTy;
+  Type VoidTy, LabelTy, Float8E4M3FNTy, Float8E5M2Ty, HalfTy, BFloatTy, FloatTy,
+      DoubleTy, MetadataTy, TokenTy;
   Type X86_FP80Ty, FP128Ty, PPC_FP128Ty, X86_MMXTy, X86_AMXTy;
   IntegerType Int1Ty, Int8Ty, Int16Ty, Int32Ty, Int64Ty, Int128Ty;
 
diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp
index c59bc3622fde5e..8464e7f565c0b5 100644
--- a/llvm/lib/IR/Type.cpp
+++ b/llvm/lib/IR/Type.cpp
@@ -35,19 +35,21 @@ using namespace llvm;
 
 Type *Type::getPrimitiveType(LLVMContext &C, TypeID IDNumber) {
   switch (IDNumber) {
   case VoidTyID      : return getVoidTy(C);
+  case Float8E4M3FNTyID : return getFloat8E4M3FNTy(C);
+  case Float8E5M2TyID : return getFloat8E5M2Ty(C);
   case HalfTyID      : return getHalfTy(C);
   case BFloatTyID    : return getBFloatTy(C);
   case FloatTyID     : return getFloatTy(C);
   case DoubleTyID    : return getDoubleTy(C);
   case X86_FP80TyID  : return getX86_FP80Ty(C);
   case FP128TyID     : return getFP128Ty(C);
   case PPC_FP128TyID : return getPPC_FP128Ty(C);
   case LabelTyID     : return getLabelTy(C);
   case MetadataTyID  : return getMetadataTy(C);
   case X86_MMXTyID   : return getX86_MMXTy(C);
   case X86_AMXTyID   : return getX86_AMXTy(C);
   case TokenTyID     : return getTokenTy(C);
   default:
     return nullptr;
   }
@@ -69,6 +71,8 @@ bool Type::isScalableTy() const {
 
 const fltSemantics &Type::getFltSemantics() const {
   switch (getTypeID()) {
+  case Float8E4M3FNTyID: return APFloat::Float8E4M3FN();
+  case Float8E5M2TyID: return APFloat::Float8E5M2();
   case HalfTyID: return APFloat::IEEEhalf();
   case BFloatTyID: return APFloat::BFloat();
   case FloatTyID: return APFloat::IEEEsingle();
@@ -92,7 +96,11 @@ bool Type::isScalableTargetExtTy() const {
 
 Type *Type::getFloatingPointTy(LLVMContext &C, const fltSemantics &S) {
   Type *Ty;
-  if (&S == &APFloat::IEEEhalf())
+  if (&S == &APFloat::Float8E4M3FN())
+    Ty = Type::getFloat8E4M3FNTy(C);
+  else if (&S == &APFloat::Float8E5M2())
+    Ty = Type::getFloat8E5M2Ty(C);
+  else if (&S == &APFloat::IEEEhalf())
     Ty = Type::getHalfTy(C);
   else if (&S == &APFloat::BFloat())
     Ty = Type::getBFloatTy(C);
@@ -165,6 +173,9 @@ bool Type::isEmptyTy() const {
 
 TypeSize Type::getPrimitiveSizeInBits() const {
   switch (getTypeID()) {
+  case Type::Float8E4M3FNTyID:
+  case Type::Float8E5M2TyID:
+    return TypeSize::getFixed(8);
   case Type::HalfTyID:
     return TypeSize::getFixed(16);
   case Type::BFloatTyID:
@@ -207,6 +218,9 @@ int Type::getFPMantissaWidth() const {
   if (auto *VTy = dyn_cast<VectorType>(this))
     return VTy->getElementType()->getFPMantissaWidth();
   assert(isFloatingPointTy() && "Not a floating point type!");
+  // Widths include the implicit leading bit (cf. half = 11, float = 24).
+  if (getTypeID() == Float8E4M3FNTyID) return 4;
+  if (getTypeID() == Float8E5M2TyID) return 3;
   if (getTypeID() == HalfTyID) return 11;
   if (getTypeID() == BFloatTyID) return 8;
   if (getTypeID() == FloatTyID) return 24;
@@ -236,6 +250,10 @@ bool Type::isSizedDerivedType(SmallPtrSetImpl<Type*> *Visited) const {
 
 Type *Type::getVoidTy(LLVMContext &C) { return &C.pImpl->VoidTy; }
 Type *Type::getLabelTy(LLVMContext &C) { return &C.pImpl->LabelTy; }
+Type *Type::getFloat8E4M3FNTy(LLVMContext &C) {
+  return &C.pImpl->Float8E4M3FNTy;
+}
+Type *Type::getFloat8E5M2Ty(LLVMContext &C) { return &C.pImpl->Float8E5M2Ty; }
 Type *Type::getHalfTy(LLVMContext &C) { return &C.pImpl->HalfTy; }
 Type *Type::getBFloatTy(LLVMContext &C) { return &C.pImpl->BFloatTy; }
 Type *Type::getFloatTy(LLVMContext &C) { return &C.pImpl->FloatTy; }
diff --git a/llvm/test/Assembler/float8.ll b/llvm/test/Assembler/float8.ll
new file mode 100644
index 00000000000000..1fa5c5564e4af0
--- /dev/null
+++ b/llvm/test/Assembler/float8.ll
@@ -0,0 +1,101 @@
+; RUN: llvm-as < %s | llvm-dis | FileCheck %s --check-prefix=ASSEM-DISASS
+; RUN: opt < %s -O3 -S | FileCheck %s --check-prefix=OPT
+; RUN: verify-uselistorder %s
+
+define float8e4m3fn @check_float8e4m3fn(float8e4m3fn %A) {
+; ASSEM-DISASS: ret float8e4m3fn %A
+    ret float8e4m3fn %A
+}
+
+define float8e5m2 @check_float8e5m2(float8e5m2 %A) {
+; ASSEM-DISASS: ret float8e5m2 %A
+    ret float8e5m2 %A
+}
+
+define float8e4m3fn @check_float8e4m3fn_literal() {
+; ASSEM-DISASS: ret float8e4m3fn 0xQ31
+    ret float8e4m3fn 0xQ31
+}
+
+define float8e5m2 @check_float8e5m2_literal() {
+; ASSEM-DISASS: ret float8e5m2 0xS31
+    ret float8e5m2 0xS31
+}
+
+define <4 x float8e4m3fn> @check_float8e4m3fn_fixed_vector() {
+; ASSEM-DISASS: ret <4 x float8e4m3fn> %tmp
+  %tmp = fadd <4 x float8e4m3fn> undef, undef
+  ret <4 x float8e4m3fn> %tmp
+}
+
+define <4 x float8e5m2> @check_float8e5m2_fixed_vector() {
+; ASSEM-DISASS: ret <4 x float8e5m2> %tmp
+  %tmp = fadd <4 x float8e5m2> undef, undef
+  ret <4 x float8e5m2> %tmp
+}
+
+define <vscale x 4 x float8e4m3fn> @check_float8e4m3fn_vector() {
+; ASSEM-DISASS: ret <vscale x 4 x float8e4m3fn> %tmp
+  %tmp = fadd <vscale x 4 x float8e4m3fn> undef, undef
+  ret <vscale x 4 x float8e4m3fn> %tmp
+}
+
+define <vscale x 4 x float8e5m2> @check_float8e5m2_vector() {
+; ASSEM-DISASS: ret <vscale x 4 x float8e5m2> %tmp
+  %tmp = fadd <vscale x 4 x float8e5m2> undef, undef
+  ret <vscale x 4 x float8e5m2> %tmp
+}
+
+define float8e4m3fn @check_float8e4m3fn_constprop() {
+  %tmp = fadd float8e4m3fn 0xQ40, 0xQ40
+; OPT: 0xQ48
+  ret float8e4m3fn %tmp
+}
+
+define float8e5m2 @check_float8e5m2_constprop() {
+  %tmp = fadd float8e5m2 0xS40, 0xS40
+; OPT: 0xS44
+  ret float8e5m2 %tmp
+}
+
+define float @check_float8e4m3fn_convert() {
+  %tmp = fpext float8e4m3fn 0xQ40 to float
+; OPT: 2.000000e+00
+  ret float %tmp
+}
+
+define float @check_float8e5m2_convert() {
+  %tmp = fpext float8e5m2 0xS40 to float
+; OPT: 2.000000e+00
+  ret float %tmp
+}
+
+; ASSEM-DISASS-LABEL: @snan_float8e5m2
+define float8e5m2 @snan_float8e5m2() {
+; ASSEM-DISASS: ret float8e5m2 0xS7D
+    ret float8e5m2 0xS7D
+}
+
+; ASSEM-DISASS-LABEL: @first_qnan_float8e5m2
+define float8e5m2 @first_qnan_float8e5m2() {
+; ASSEM-DISASS: ret float8e5m2 0xS7E
+    ret float8e5m2 0xS7E
+}
+
+; ASSEM-DISASS-LABEL: @second_qnan_float8e5m2
+define float8e5m2 @second_qnan_float8e5m2() {
+; ASSEM-DISASS: ret float8e5m2 0xS7F
+    ret float8e5m2 0xS7F
+}
+
+; ASSEM-DISASS-LABEL: @inf_float8e5m2
+define float8e5m2 @inf_float8e5m2() {
+; ASSEM-DISASS: ret float8e5m2 0xS7C
+    ret float8e5m2 0xS7C
+}
+
+; ASSEM-DISASS-LABEL: @first_qnan_float8e4m3fn
+define float8e4m3fn @first_qnan_float8e4m3fn() {
+; ASSEM-DISASS: ret float8e4m3fn 0xQ7F
+    ret float8e4m3fn 0xQ7F
+}
diff --git a/llvm/tools/llvm-c-test/echo.cpp b/llvm/tools/llvm-c-test/echo.cpp
index 347863638849ce..7d3316372b0744 100644
--- a/llvm/tools/llvm-c-test/echo.cpp
+++ b/llvm/tools/llvm-c-test/echo.cpp
@@ -73,10 +73,14 @@ struct TypeCloner {
     switch (Kind) {
       case LLVMVoidTypeKind:
         return LLVMVoidTypeInContext(Ctx);
+      case LLVMFloat8E4M3FNTypeKind:
+        return LLVMFloat8E4M3FNTypeInContext(Ctx);
+      case LLVMFloat8E5M2TypeKind:
+        return LLVMFloat8E5M2TypeInContext(Ctx);
       case LLVMHalfTypeKind:
         return LLVMHalfTypeInContext(Ctx);
       case LLVMBFloatTypeKind:
-        return LLVMHalfTypeInContext(Ctx);
+        return LLVMBFloatTypeInContext(Ctx);
       case LLVMFloatTypeKind:
         return LLVMFloatTypeInContext(Ctx);
       case LLVMDoubleTypeKind:



More information about the llvm-commits mailing list