[clang] 8c24f33 - [IR][BFloat] Add BFloat IR type

Ties Stuij via cfe-commits cfe-commits at lists.llvm.org
Fri May 15 06:44:09 PDT 2020


Author: Ties Stuij
Date: 2020-05-15T14:43:43+01:00
New Revision: 8c24f33158d81d5f4b0c5d27c2f07396f0f1484b

URL: https://github.com/llvm/llvm-project/commit/8c24f33158d81d5f4b0c5d27c2f07396f0f1484b
DIFF: https://github.com/llvm/llvm-project/commit/8c24f33158d81d5f4b0c5d27c2f07396f0f1484b.diff

LOG: [IR][BFloat] Add BFloat IR type

Summary:
The BFloat IR type is introduced to provide support for, initially, the BFloat16
datatype introduced with the Armv8.6 architecture (optional from Armv8.2
onwards). It has an 8-bit exponent and a 7-bit mantissa and behaves like an IEEE
754 floating point IR type.
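
For illustration only, here is a minimal C++ sketch of the layout described above. It assumes that bfloat is the upper 16 bits of an IEEE 754 binary32 value: 1 sign bit, 8 exponent bits (same bias as float), and 7 explicit significand bits. The helper names (decodeBF16, floatToBF16, bf16ToFloat) are hypothetical and are not part of LLVM. The narrowing uses the common round-to-nearest-even bit trick rather than APFloat's implementation.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical helpers, not LLVM API. bfloat = top half of a binary32 float.
struct BF16Fields {
  unsigned Sign;     // bit 15
  unsigned Exponent; // bits 14..7, biased by 127 like float
  unsigned Fraction; // bits 6..0 (7-bit significand field)
};

BF16Fields decodeBF16(uint16_t Bits) {
  return {static_cast<unsigned>(Bits >> 15),
          static_cast<unsigned>((Bits >> 7) & 0xFF),
          static_cast<unsigned>(Bits & 0x7F)};
}

// Narrow a float to bfloat bits, rounding to nearest, ties to even.
uint16_t floatToBF16(float F) {
  uint32_t U;
  std::memcpy(&U, &F, sizeof(U));
  if ((U & 0x7F800000u) == 0x7F800000u && (U & 0x007FFFFFu) != 0)
    return static_cast<uint16_t>((U >> 16) | 0x0040u); // NaN: force quiet bit
  uint32_t LSB = (U >> 16) & 1u; // lowest bit that survives truncation
  U += 0x7FFFu + LSB;            // round half to even
  return static_cast<uint16_t>(U >> 16);
}

// Widening is exact: just put the 16 bits back on top of a float.
float bf16ToFloat(uint16_t Bits) {
  uint32_t U = static_cast<uint32_t>(Bits) << 16;
  float F;
  std::memcpy(&F, &U, sizeof(F));
  return F;
}
```

The NaN branch exists because a NaN whose payload sits only in the low 16 bits (for example 0x7F800001) would otherwise truncate to the infinity pattern 0x7F80.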

This is part of a patch series upstreaming Armv8.6 features. Subsequent patches
will upstream intrinsics support and C language (Clang) support for BFloat.

Reviewers: SjoerdMeijer, rjmccall, rsmith, liutianle, RKSimon, craig.topper, jfb, LukeGeeson, sdesmalen, deadalnix, ctetreau

Subscribers: hiraditya, llvm-commits, danielkiss, arphaman, kristof.beyls, dexonsmith

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D78190

Added: 
    llvm/test/Assembler/bfloat.ll

Modified: 
    clang/lib/Sema/SemaOpenMP.cpp
    llvm/docs/BitCodeFormat.rst
    llvm/docs/LangRef.rst
    llvm/include/llvm-c/Core.h
    llvm/include/llvm/ADT/APFloat.h
    llvm/include/llvm/Bitcode/LLVMBitCodes.h
    llvm/include/llvm/IR/Constants.h
    llvm/include/llvm/IR/DataLayout.h
    llvm/include/llvm/IR/IRBuilder.h
    llvm/include/llvm/IR/Type.h
    llvm/lib/AsmParser/LLLexer.cpp
    llvm/lib/AsmParser/LLParser.cpp
    llvm/lib/Bitcode/Reader/BitcodeReader.cpp
    llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
    llvm/lib/CodeGen/MIRParser/MILexer.cpp
    llvm/lib/IR/AsmWriter.cpp
    llvm/lib/IR/Constants.cpp
    llvm/lib/IR/Core.cpp
    llvm/lib/IR/DataLayout.cpp
    llvm/lib/IR/Function.cpp
    llvm/lib/IR/LLVMContextImpl.cpp
    llvm/lib/IR/LLVMContextImpl.h
    llvm/lib/IR/Type.cpp
    llvm/lib/Support/APFloat.cpp
    llvm/lib/Target/Hexagon/HexagonTargetObjectFile.cpp
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/tools/llvm-c-test/echo.cpp

Removed: 
    


################################################################################
diff  --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 544dc6134387..e03b926bc581 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -14936,9 +14936,9 @@ static bool actOnOMPReductionKindClause(
         if (auto *ComplexTy = OrigType->getAs<ComplexType>())
           Type = ComplexTy->getElementType();
         if (Type->isRealFloatingType()) {
-          llvm::APFloat InitValue =
-              llvm::APFloat::getAllOnesValue(Context.getTypeSize(Type),
-                                             /*isIEEE=*/true);
+          llvm::APFloat InitValue = llvm::APFloat::getAllOnesValue(
+              Context.getFloatTypeSemantics(Type),
+              Context.getTypeSize(Type));
           Init = FloatingLiteral::Create(Context, InitValue, /*isexact=*/true,
                                          Type, ELoc);
         } else if (Type->isScalarType()) {

diff  --git a/llvm/docs/BitCodeFormat.rst b/llvm/docs/BitCodeFormat.rst
index dce84620fd7b..4fdccc87cfd2 100644
--- a/llvm/docs/BitCodeFormat.rst
+++ b/llvm/docs/BitCodeFormat.rst
@@ -1107,6 +1107,14 @@ TYPE_CODE_HALF Record
 The ``HALF`` record (code 10) adds a ``half`` (16-bit floating point) type to
 the type table.
 
+TYPE_CODE_BFLOAT Record
+^^^^^^^^^^^^^^^^^^^^^
+
+``[BFLOAT]``
+
+The ``BFLOAT`` record (code 23) adds a ``bfloat`` (16-bit brain floating point)
+type to the type table.
+
 TYPE_CODE_FLOAT Record
 ^^^^^^^^^^^^^^^^^^^^^^
 

diff  --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 240dbd68e9e0..07320de7cf4b 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -2963,6 +2963,12 @@ Floating-Point Types
    * - ``half``
      - 16-bit floating-point value
 
+   * - ``bfloat``
+     - 16-bit "brain" floating-point value (7-bit significand).  Provides the
+       same number of exponent bits as ``float``, so that it matches its dynamic
+       range, but with greatly reduced precision.  Used in Intel's AVX-512 BF16
+       extensions and Arm's ARMv8.6-A extensions, among others.
+
    * - ``float``
      - 32-bit floating-point value
 
@@ -2970,7 +2976,7 @@ Floating-Point Types
      - 64-bit floating-point value
 
    * - ``fp128``
-     - 128-bit floating-point value (112-bit mantissa)
+     - 128-bit floating-point value (112-bit significand)
 
    * - ``x86_fp80``
      -  80-bit floating-point value (X87)
@@ -3303,20 +3309,20 @@ number of digits. For example, NaN's, infinities, and other special
 values are represented in their IEEE hexadecimal format so that assembly
 and disassembly do not cause any bits to change in the constants.
 
-When using the hexadecimal form, constants of types half, float, and
-double are represented using the 16-digit form shown above (which
-matches the IEEE754 representation for double); half and float values
-must, however, be exactly representable as IEEE 754 half and single
-precision, respectively. Hexadecimal format is always used for long
-double, and there are three forms of long double. The 80-bit format used
-by x86 is represented as ``0xK`` followed by 20 hexadecimal digits. The
-128-bit format used by PowerPC (two adjacent doubles) is represented by
-``0xM`` followed by 32 hexadecimal digits. The IEEE 128-bit format is
-represented by ``0xL`` followed by 32 hexadecimal digits. Long doubles
-will only work if they match the long double format on your target.
-The IEEE 16-bit format (half precision) is represented by ``0xH``
-followed by 4 hexadecimal digits. All hexadecimal formats are big-endian
-(sign bit at the left).
+When using the hexadecimal form, constants of types bfloat, half, float, and
+double are represented using the 16-digit form shown above (which matches the
+IEEE754 representation for double); bfloat, half and float values must, however,
+be exactly representable as bfloat, IEEE 754 half, and IEEE 754 single
+precision respectively. Hexadecimal format is always used for long double, and
+there are three forms of long double. The 80-bit format used by x86 is
+represented as ``0xK`` followed by 20 hexadecimal digits. The 128-bit format
+used by PowerPC (two adjacent doubles) is represented by ``0xM`` followed by 32
+hexadecimal digits. The IEEE 128-bit format is represented by ``0xL`` followed
+by 32 hexadecimal digits. Long doubles will only work if they match the long
+double format on your target.  The IEEE 16-bit format (half precision) is
+represented by ``0xH`` followed by 4 hexadecimal digits. The bfloat 16-bit
+format is represented by ``0xR`` followed by 4 hexadecimal digits. All
+hexadecimal formats are big-endian (sign bit at the left).
 
 There are no constants of type x86_mmx.
 

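The LangRef paragraph above allows two textual forms for a bfloat constant: the 16-digit double form and the new 0xR form. The following is a hedged C++ sketch that produces both from raw bfloat bits. It assumes the binary32-upper-half layout, so widening bfloat to float and then to double is exact. bf16ToRHex and bf16ToDoubleHex are illustrative names, not AsmWriter functions.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <string>

// "0xR" followed by exactly 4 uppercase hex digits, big-endian (sign at left).
std::string bf16ToRHex(uint16_t Bits) {
  char Buf[8];
  std::snprintf(Buf, sizeof(Buf), "0xR%04X", static_cast<unsigned>(Bits));
  return Buf;
}

// 16-digit form: widen bfloat -> float -> double (both steps exact), then
// print the IEEE double bit pattern.
std::string bf16ToDoubleHex(uint16_t Bits) {
  uint32_t U32 = static_cast<uint32_t>(Bits) << 16;
  float F;
  std::memcpy(&F, &U32, sizeof(F));
  double D = F;
  uint64_t U64;
  std::memcpy(&U64, &D, sizeof(U64));
  char Buf[24];
  std::snprintf(Buf, sizeof(Buf), "0x%016llX",
                static_cast<unsigned long long>(U64));
  return Buf;
}
```

For 1.0 in bfloat (bits 0x3F80) this yields 0xR3F80 and 0x3FF0000000000000 respectively.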
diff  --git a/llvm/include/llvm-c/Core.h b/llvm/include/llvm-c/Core.h
index 25802edc9982..1991dd98a76b 100644
--- a/llvm/include/llvm-c/Core.h
+++ b/llvm/include/llvm-c/Core.h
@@ -146,6 +146,7 @@ typedef enum {
 typedef enum {
   LLVMVoidTypeKind,        /**< type with no size */
   LLVMHalfTypeKind,        /**< 16 bit floating point type */
+  LLVMBFloatTypeKind,      /**< 16 bit brain floating point type */
   LLVMFloatTypeKind,       /**< 32 bit floating point type */
   LLVMDoubleTypeKind,      /**< 64 bit floating point type */
   LLVMX86_FP80TypeKind,    /**< 80 bit floating point type (X87) */
@@ -1163,6 +1164,11 @@ unsigned LLVMGetIntTypeWidth(LLVMTypeRef IntegerTy);
  */
 LLVMTypeRef LLVMHalfTypeInContext(LLVMContextRef C);
 
+/**
+ * Obtain a 16-bit brain floating point type from a context.
+ */
+LLVMTypeRef LLVMBFloatTypeInContext(LLVMContextRef C);
+
 /**
  * Obtain a 32-bit floating point type from a context.
  */
@@ -1195,6 +1201,7 @@ LLVMTypeRef LLVMPPCFP128TypeInContext(LLVMContextRef C);
  * These map to the functions in this group of the same name.
  */
 LLVMTypeRef LLVMHalfType(void);
+LLVMTypeRef LLVMBFloatType(void);
 LLVMTypeRef LLVMFloatType(void);
 LLVMTypeRef LLVMDoubleType(void);
 LLVMTypeRef LLVMX86FP80Type(void);

diff  --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h
index 1c17f10691e7..44857f777502 100644
--- a/llvm/include/llvm/ADT/APFloat.h
+++ b/llvm/include/llvm/ADT/APFloat.h
@@ -151,6 +151,7 @@ struct APFloatBase {
   /// @{
   enum Semantics {
     S_IEEEhalf,
+    S_BFloat,
     S_IEEEsingle,
     S_IEEEdouble,
     S_x87DoubleExtended,
@@ -162,6 +163,7 @@ struct APFloatBase {
   static Semantics SemanticsToEnum(const llvm::fltSemantics &Sem);
 
   static const fltSemantics &IEEEhalf() LLVM_READNONE;
+  static const fltSemantics &BFloat() LLVM_READNONE;
   static const fltSemantics &IEEEsingle() LLVM_READNONE;
   static const fltSemantics &IEEEdouble() LLVM_READNONE;
   static const fltSemantics &IEEEquad() LLVM_READNONE;
@@ -541,6 +543,7 @@ class IEEEFloat final : public APFloatBase {
   /// @}
 
   APInt convertHalfAPFloatToAPInt() const;
+  APInt convertBFloatAPFloatToAPInt() const;
   APInt convertFloatAPFloatToAPInt() const;
   APInt convertDoubleAPFloatToAPInt() const;
   APInt convertQuadrupleAPFloatToAPInt() const;
@@ -548,6 +551,7 @@ class IEEEFloat final : public APFloatBase {
   APInt convertPPCDoubleDoubleAPFloatToAPInt() const;
   void initFromAPInt(const fltSemantics *Sem, const APInt &api);
   void initFromHalfAPInt(const APInt &api);
+  void initFromBFloatAPInt(const APInt &api);
   void initFromFloatAPInt(const APInt &api);
   void initFromDoubleAPInt(const APInt &api);
   void initFromQuadrupleAPInt(const APInt &api);
@@ -954,9 +958,10 @@ class APFloat : public APFloatBase {
 
   /// Returns a float which is bitcasted from an all one value int.
   ///
+  /// \param Semantics - type float semantics
   /// \param BitWidth - Select float type
-  /// \param isIEEE   - If 128 bit number, select between PPC and IEEE
-  static APFloat getAllOnesValue(unsigned BitWidth, bool isIEEE = false);
+  static APFloat getAllOnesValue(const fltSemantics &Semantics,
+                                 unsigned BitWidth);
 
   /// Used to insert APFloat objects, or objects that contain APFloat objects,
   /// into FoldingSets.

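The getAllOnesValue change above replaces a bit-width/isIEEE selector with an explicit fltSemantics. A bit width alone cannot distinguish half from bfloat, because both are 16 bits wide and the same 16-bit pattern means different things in each format. The sketch below decodes one pattern both ways using hypothetical plain-C++ helpers, not APFloat. 0x3C00 is 1.0 as half but 2^-7 as bfloat, and 0x3F80 is 1.0 as bfloat but 1.875 as half. The all-ones pattern 0xFFFF happens to be a NaN in both formats.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// IEEE 754 binary16 ("half"): 1 sign, 5 exponent (bias 15), 10 fraction bits.
double decodeHalf(uint16_t Bits) {
  double Sign = (Bits >> 15) ? -1.0 : 1.0;
  int Exp = (Bits >> 10) & 0x1F;
  int Frac = Bits & 0x3FF;
  if (Exp == 0x1F) // infinity or NaN
    return Frac ? std::nan("") : Sign * INFINITY;
  if (Exp == 0)    // zero or subnormal: Frac * 2^-24
    return Sign * std::ldexp(static_cast<double>(Frac), -24);
  // normal: (1 + Frac/1024) * 2^(Exp-15) == (1024 + Frac) * 2^(Exp-25)
  return Sign * std::ldexp(static_cast<double>(1024 + Frac), Exp - 25);
}

// bfloat: 1 sign, 8 exponent (bias 127), 7 fraction bits; it is the upper
// half of a binary32 float, so decoding is a 16-bit shift.
double decodeBFloat(uint16_t Bits) {
  uint32_t U = static_cast<uint32_t>(Bits) << 16;
  float F;
  std::memcpy(&F, &U, sizeof(F));
  return static_cast<double>(F);
}
```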
diff  --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index e614337e5852..2f09ad3e7c59 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -166,7 +166,9 @@ enum TypeCodes {
 
   TYPE_CODE_FUNCTION = 21, // FUNCTION: [vararg, retty, paramty x N]
 
-  TYPE_CODE_TOKEN = 22 // TOKEN
+  TYPE_CODE_TOKEN = 22, // TOKEN
+
+  TYPE_CODE_BFLOAT = 23 // BRAIN FLOATING POINT
 };
 
 enum OperandBundleTagCode {

diff  --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index b31bcb751ab7..25d8f6afc1bd 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -721,14 +721,15 @@ class ConstantDataArray final : public ConstantDataSequential {
     return getImpl(Data, Ty);
   }
 
-  /// getFP() constructors - Return a constant with array type with an element
-  /// count and element type of float with precision matching the number of
-  /// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits,
-  /// double for 64bits) Note that this can return a ConstantAggregateZero
-  /// object.
-  static Constant *getFP(LLVMContext &Context, ArrayRef<uint16_t> Elts);
-  static Constant *getFP(LLVMContext &Context, ArrayRef<uint32_t> Elts);
-  static Constant *getFP(LLVMContext &Context, ArrayRef<uint64_t> Elts);
+  /// getFP() constructors - Return a constant of array type with a float
+  /// element type taken from argument `ElementType', and count taken from
+  /// argument `Elts'.  The amount of bits of the contained type must match the
+  /// number of bits of the type contained in the passed in ArrayRef.
+  /// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
+  /// that this can return a ConstantAggregateZero object.
+  static Constant *getFP(Type *ElementType, ArrayRef<uint16_t> Elts);
+  static Constant *getFP(Type *ElementType, ArrayRef<uint32_t> Elts);
+  static Constant *getFP(Type *ElementType, ArrayRef<uint64_t> Elts);
 
   /// This method constructs a CDS and initializes it with a text string.
   /// The default behavior (AddNull==true) causes a null terminator to
@@ -780,14 +781,15 @@ class ConstantDataVector final : public ConstantDataSequential {
   static Constant *get(LLVMContext &Context, ArrayRef<float> Elts);
   static Constant *get(LLVMContext &Context, ArrayRef<double> Elts);
 
-  /// getFP() constructors - Return a constant with vector type with an element
-  /// count and element type of float with the precision matching the number of
-  /// bits in the ArrayRef passed in.  (i.e. half for 16bits, float for 32bits,
-  /// double for 64bits) Note that this can return a ConstantAggregateZero
-  /// object.
-  static Constant *getFP(LLVMContext &Context, ArrayRef<uint16_t> Elts);
-  static Constant *getFP(LLVMContext &Context, ArrayRef<uint32_t> Elts);
-  static Constant *getFP(LLVMContext &Context, ArrayRef<uint64_t> Elts);
+  /// getFP() constructors - Return a constant of vector type with a float
+  /// element type taken from argument `ElementType', and count taken from
+  /// argument `Elts'.  The amount of bits of the contained type must match the
+  /// number of bits of the type contained in the passed in ArrayRef.
+  /// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
+  /// that this can return a ConstantAggregateZero object.
+  static Constant *getFP(Type *ElementType, ArrayRef<uint16_t> Elts);
+  static Constant *getFP(Type *ElementType, ArrayRef<uint32_t> Elts);
+  static Constant *getFP(Type *ElementType, ArrayRef<uint64_t> Elts);
 
   /// Return a ConstantVector with the specified constant in each element.
   /// The specified constant has to be a of a compatible type (i8/i16/

diff  --git a/llvm/include/llvm/IR/DataLayout.h b/llvm/include/llvm/IR/DataLayout.h
index 010469c8107b..e8fab02fe2ff 100644
--- a/llvm/include/llvm/IR/DataLayout.h
+++ b/llvm/include/llvm/IR/DataLayout.h
@@ -651,6 +651,7 @@ inline TypeSize DataLayout::getTypeSizeInBits(Type *Ty) const {
   case Type::IntegerTyID:
     return TypeSize::Fixed(Ty->getIntegerBitWidth());
   case Type::HalfTyID:
+  case Type::BFloatTyID:
     return TypeSize::Fixed(16);
   case Type::FloatTyID:
     return TypeSize::Fixed(32);

diff  --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
index 6e431bc3ac0f..b6dca11527d6 100644
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -477,6 +477,11 @@ class IRBuilderBase {
     return Type::getHalfTy(Context);
   }
 
+  /// Fetch the type representing a 16-bit brain floating point value.
+  Type *getBFloatTy() {
+    return Type::getBFloatTy(Context);
+  }
+
   /// Fetch the type representing a 32-bit floating point value.
   Type *getFloatTy() {
     return Type::getFloatTy(Context);

diff  --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h
index 618eee06dcf7..5d6c0c676f5e 100644
--- a/llvm/include/llvm/IR/Type.h
+++ b/llvm/include/llvm/IR/Type.h
@@ -54,27 +54,28 @@ class Type {
   ///
   enum TypeID {
     // PrimitiveTypes - make sure LastPrimitiveTyID stays up to date.
-    VoidTyID = 0,  ///<  0: type with no size
-    HalfTyID,      ///<  1: 16-bit floating point type
-    FloatTyID,     ///<  2: 32-bit floating point type
-    DoubleTyID,    ///<  3: 64-bit floating point type
-    X86_FP80TyID,  ///<  4: 80-bit floating point type (X87)
-    FP128TyID,     ///<  5: 128-bit floating point type (112-bit mantissa)
-    PPC_FP128TyID, ///<  6: 128-bit floating point type (two 64-bits, PowerPC)
-    LabelTyID,     ///<  7: Labels
-    MetadataTyID,  ///<  8: Metadata
-    X86_MMXTyID,   ///<  9: MMX vectors (64 bits, X86 specific)
-    TokenTyID,     ///< 10: Tokens
+    VoidTyID = 0,    ///<  0: type with no size
+    HalfTyID,        ///<  1: 16-bit floating point type
+    BFloatTyID,      ///<  2: 16-bit floating point type (7-bit significand)
+    FloatTyID,       ///<  3: 32-bit floating point type
+    DoubleTyID,      ///<  4: 64-bit floating point type
+    X86_FP80TyID,    ///<  5: 80-bit floating point type (X87)
+    FP128TyID,       ///<  6: 128-bit floating point type (112-bit significand)
+    PPC_FP128TyID,   ///<  7: 128-bit floating point type (two 64-bits, PowerPC)
+    LabelTyID,       ///<  8: Labels
+    MetadataTyID,    ///<  9: Metadata
+    X86_MMXTyID,     ///< 10: MMX vectors (64 bits, X86 specific)
+    TokenTyID,       ///< 11: Tokens
 
     // Derived types... see DerivedTypes.h file.
     // Make sure FirstDerivedTyID stays up to date!
-    IntegerTyID,       ///< 11: Arbitrary bit width integers
-    FunctionTyID,      ///< 12: Functions
-    StructTyID,        ///< 13: Structures
-    ArrayTyID,         ///< 14: Arrays
-    PointerTyID,       ///< 15: Pointers
-    FixedVectorTyID,   ///< 16: Fixed width SIMD vector type
-    ScalableVectorTyID ///< 17: Scalable SIMD vector type
+    IntegerTyID,       ///< 12: Arbitrary bit width integers
+    FunctionTyID,      ///< 13: Functions
+    StructTyID,        ///< 14: Structures
+    ArrayTyID,         ///< 15: Arrays
+    PointerTyID,       ///< 16: Pointers
+    FixedVectorTyID,   ///< 17: Fixed width SIMD vector type
+    ScalableVectorTyID ///< 18: Scalable SIMD vector type
   };
 
 private:
@@ -140,6 +141,9 @@ class Type {
   /// Return true if this is 'half', a 16-bit IEEE fp type.
   bool isHalfTy() const { return getTypeID() == HalfTyID; }
 
+  /// Return true if this is 'bfloat', a 16-bit bfloat type.
+  bool isBFloatTy() const { return getTypeID() == BFloatTyID; }
+
   /// Return true if this is 'float', a 32-bit IEEE fp type.
   bool isFloatTy() const { return getTypeID() == FloatTyID; }
 
@@ -157,8 +161,8 @@ class Type {
 
   /// Return true if this is one of the six floating-point types
   bool isFloatingPointTy() const {
-    return getTypeID() == HalfTyID || getTypeID() == FloatTyID ||
-           getTypeID() == DoubleTyID ||
+    return getTypeID() == HalfTyID || getTypeID() == BFloatTyID ||
+           getTypeID() == FloatTyID || getTypeID() == DoubleTyID ||
            getTypeID() == X86_FP80TyID || getTypeID() == FP128TyID ||
            getTypeID() == PPC_FP128TyID;
   }
@@ -166,6 +170,7 @@ class Type {
   const fltSemantics &getFltSemantics() const {
     switch (getTypeID()) {
     case HalfTyID: return APFloat::IEEEhalf();
+    case BFloatTyID: return APFloat::BFloat();
     case FloatTyID: return APFloat::IEEEsingle();
     case DoubleTyID: return APFloat::IEEEdouble();
     case X86_FP80TyID: return APFloat::x87DoubleExtended();
@@ -387,6 +392,7 @@ class Type {
   static Type *getVoidTy(LLVMContext &C);
   static Type *getLabelTy(LLVMContext &C);
   static Type *getHalfTy(LLVMContext &C);
+  static Type *getBFloatTy(LLVMContext &C);
   static Type *getFloatTy(LLVMContext &C);
   static Type *getDoubleTy(LLVMContext &C);
   static Type *getMetadataTy(LLVMContext &C);
@@ -422,6 +428,7 @@ class Type {
   // types as pointee.
   //
   static PointerType *getHalfPtrTy(LLVMContext &C, unsigned AS = 0);
+  static PointerType *getBFloatPtrTy(LLVMContext &C, unsigned AS = 0);
   static PointerType *getFloatPtrTy(LLVMContext &C, unsigned AS = 0);
   static PointerType *getDoublePtrTy(LLVMContext &C, unsigned AS = 0);
   static PointerType *getX86_FP80PtrTy(LLVMContext &C, unsigned AS = 0);

diff  --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index 06631fc0f7bb..eb85ef7783bd 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -820,6 +820,7 @@ lltok::Kind LLLexer::LexIdentifier() {
 
   TYPEKEYWORD("void",      Type::getVoidTy(Context));
   TYPEKEYWORD("half",      Type::getHalfTy(Context));
+  TYPEKEYWORD("bfloat",    Type::getBFloatTy(Context));
   TYPEKEYWORD("float",     Type::getFloatTy(Context));
   TYPEKEYWORD("double",    Type::getDoubleTy(Context));
   TYPEKEYWORD("x86_fp80",  Type::getX86_FP80Ty(Context));
@@ -985,11 +986,13 @@ lltok::Kind LLLexer::LexIdentifier() {
 ///    HexFP128Constant  0xL[0-9A-Fa-f]+
 ///    HexPPC128Constant 0xM[0-9A-Fa-f]+
 ///    HexHalfConstant   0xH[0-9A-Fa-f]+
+///    HexBFloatConstant 0xR[0-9A-Fa-f]+
 lltok::Kind LLLexer::Lex0x() {
   CurPtr = TokStart + 2;
 
   char Kind;
-  if ((CurPtr[0] >= 'K' && CurPtr[0] <= 'M') || CurPtr[0] == 'H') {
+  if ((CurPtr[0] >= 'K' && CurPtr[0] <= 'M') || CurPtr[0] == 'H' ||
+      CurPtr[0] == 'R') {
     Kind = *CurPtr++;
   } else {
     Kind = 'J';
@@ -1007,7 +1010,7 @@ lltok::Kind LLLexer::Lex0x() {
   if (Kind == 'J') {
     // HexFPConstant - Floating point constant represented in IEEE format as a
     // hexadecimal number for when exponential notation is not precise enough.
-    // Half, Float, and double only.
+    // Half, BFloat, Float, and double only.
     APFloatVal = APFloat(APFloat::IEEEdouble(),
                          APInt(64, HexIntToVal(TokStart + 2, CurPtr)));
     return lltok::APFloat;
@@ -1035,6 +1038,11 @@ lltok::Kind LLLexer::Lex0x() {
     APFloatVal = APFloat(APFloat::IEEEhalf(),
                          APInt(16,HexIntToVal(TokStart+3, CurPtr)));
     return lltok::APFloat;
+  case 'R':
+    // Brain floating point
+    APFloatVal = APFloat(APFloat::BFloat(),
+                         APInt(16, HexIntToVal(TokStart + 3, CurPtr)));
+    return lltok::APFloat;
   }
 }
 

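The Lex0x change above accepts 'R' as one more format letter after "0x". Below is a simplified, hypothetical C++ sketch of that dispatch and of decoding a 0xR token into raw bfloat bits. It is not the LLLexer code, which works on raw character pointers and uses HexIntToVal. For simplicity the sketch requires exactly four hex digits. Like the lexer, it only recognizes uppercase format letters.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical: returns the format letter after "0x" ('H', 'K', 'L', 'M',
// 'R'), 'J' for the plain IEEE-double form, or '\0' if not a 0x token.
char hexFPKind(const std::string &Tok) {
  if (Tok.size() < 3 || Tok[0] != '0' || Tok[1] != 'x')
    return '\0';
  char C = Tok[2];
  if ((C >= 'K' && C <= 'M') || C == 'H' || C == 'R')
    return C;
  return 'J';
}

// Hypothetical: parse "0xR" + 4 hex digits into raw bfloat bits.
std::optional<uint16_t> parseBFloatHex(const std::string &Tok) {
  if (hexFPKind(Tok) != 'R' || Tok.size() != 7)
    return std::nullopt;
  uint16_t Bits = 0;
  for (std::size_t I = 3; I < Tok.size(); ++I) {
    unsigned char C = static_cast<unsigned char>(Tok[I]);
    if (!std::isxdigit(C))
      return std::nullopt;
    unsigned Digit = std::isdigit(C) ? C - '0' : std::toupper(C) - 'A' + 10;
    Bits = static_cast<uint16_t>((Bits << 4) | Digit);
  }
  return Bits;
}
```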
diff  --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index ce1e9d23210e..d045bcd741f5 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -5247,13 +5247,16 @@ bool LLParser::ConvertValIDToValue(Type *Ty, ValID &ID, Value *&V,
         !ConstantFP::isValueValidForType(Ty, ID.APFloatVal))
       return Error(ID.Loc, "floating point constant invalid for type");
 
-    // The lexer has no type info, so builds all half, float, and double FP
-    // constants as double.  Fix this here.  Long double does not need this.
+    // The lexer has no type info, so builds all half, bfloat, float, and double
+    // FP constants as double.  Fix this here.  Long double does not need this.
     if (&ID.APFloatVal.getSemantics() == &APFloat::IEEEdouble()) {
       bool Ignored;
       if (Ty->isHalfTy())
         ID.APFloatVal.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven,
                               &Ignored);
+      else if (Ty->isBFloatTy())
+        ID.APFloatVal.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven,
+                              &Ignored);
       else if (Ty->isFloatTy())
         ID.APFloatVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
                               &Ignored);

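The parser change above narrows double-typed FP constants to bfloat. The earlier ConstantFP::isValueValidForType check in the same function rejects values that would lose information. What follows is a minimal C++ sketch of that exactness test, assuming bfloat is the upper half of binary32. Under that assumption, a double is exactly representable as bfloat if it survives the trip to float unchanged and the low 16 bits of that float are zero. isExactBFloat is a hypothetical helper, and NaNs are deliberately left out of scope.

```cpp
#include <cassert>
#include <cfloat>
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical helper: true if D can be stored as bfloat with no rounding.
// NaNs are out of scope for this sketch and reported as not representable.
bool isExactBFloat(double D) {
  if (std::isnan(D))
    return false;
  if (std::isinf(D))
    return true;
  if (std::fabs(D) > FLT_MAX)
    return false; // too large for float, hence for bfloat too
  float F = static_cast<float>(D);
  if (static_cast<double>(F) != D)
    return false; // not even exact as float
  uint32_t U;
  std::memcpy(&U, &F, sizeof(U));
  return (U & 0xFFFFu) == 0; // bfloat keeps only the upper 16 bits
}
```

For example, 1 + 2^-7 is accepted (it fits in 7 significand bits), while 1 + 2^-8 is rejected.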
diff  --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index bdc0fa7e8a93..21759c5091af 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -1720,6 +1720,9 @@ Error BitcodeReader::parseTypeTableBody() {
     case bitc::TYPE_CODE_HALF:     // HALF
       ResultTy = Type::getHalfTy(Context);
       break;
+    case bitc::TYPE_CODE_BFLOAT:    // BFLOAT
+      ResultTy = Type::getBFloatTy(Context);
+      break;
     case bitc::TYPE_CODE_FLOAT:     // FLOAT
       ResultTy = Type::getFloatTy(Context);
       break;
@@ -2429,6 +2432,9 @@ Error BitcodeReader::parseConstants() {
       if (CurTy->isHalfTy())
         V = ConstantFP::get(Context, APFloat(APFloat::IEEEhalf(),
                                              APInt(16, (uint16_t)Record[0])));
+      else if (CurTy->isBFloatTy())
+        V = ConstantFP::get(Context, APFloat(APFloat::BFloat(),
+                                             APInt(16, (uint32_t)Record[0])));
       else if (CurTy->isFloatTy())
         V = ConstantFP::get(Context, APFloat(APFloat::IEEEsingle(),
                                              APInt(32, (uint32_t)Record[0])));
@@ -2526,21 +2532,27 @@ Error BitcodeReader::parseConstants() {
       } else if (EltTy->isHalfTy()) {
         SmallVector<uint16_t, 16> Elts(Record.begin(), Record.end());
         if (isa<VectorType>(CurTy))
-          V = ConstantDataVector::getFP(Context, Elts);
+          V = ConstantDataVector::getFP(EltTy, Elts);
+        else
+          V = ConstantDataArray::getFP(EltTy, Elts);
+      } else if (EltTy->isBFloatTy()) {
+        SmallVector<uint16_t, 16> Elts(Record.begin(), Record.end());
+        if (isa<VectorType>(CurTy))
+          V = ConstantDataVector::getFP(EltTy, Elts);
         else
-          V = ConstantDataArray::getFP(Context, Elts);
+          V = ConstantDataArray::getFP(EltTy, Elts);
       } else if (EltTy->isFloatTy()) {
         SmallVector<uint32_t, 16> Elts(Record.begin(), Record.end());
         if (isa<VectorType>(CurTy))
-          V = ConstantDataVector::getFP(Context, Elts);
+          V = ConstantDataVector::getFP(EltTy, Elts);
         else
-          V = ConstantDataArray::getFP(Context, Elts);
+          V = ConstantDataArray::getFP(EltTy, Elts);
       } else if (EltTy->isDoubleTy()) {
         SmallVector<uint64_t, 16> Elts(Record.begin(), Record.end());
         if (isa<VectorType>(CurTy))
-          V = ConstantDataVector::getFP(Context, Elts);
+          V = ConstantDataVector::getFP(EltTy, Elts);
         else
-          V = ConstantDataArray::getFP(Context, Elts);
+          V = ConstantDataArray::getFP(EltTy, Elts);
       } else {
         return error("Invalid type for value");
       }

diff  --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 9e389fc64a64..5b62a475203c 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -881,6 +881,7 @@ void ModuleBitcodeWriter::writeTypeTable() {
     switch (T->getTypeID()) {
     case Type::VoidTyID:      Code = bitc::TYPE_CODE_VOID;      break;
     case Type::HalfTyID:      Code = bitc::TYPE_CODE_HALF;      break;
+    case Type::BFloatTyID:    Code = bitc::TYPE_CODE_BFLOAT;    break;
     case Type::FloatTyID:     Code = bitc::TYPE_CODE_FLOAT;     break;
     case Type::DoubleTyID:    Code = bitc::TYPE_CODE_DOUBLE;    break;
     case Type::X86_FP80TyID:  Code = bitc::TYPE_CODE_X86_FP80;  break;
@@ -2387,7 +2388,8 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
     } else if (const ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
       Code = bitc::CST_CODE_FLOAT;
       Type *Ty = CFP->getType();
-      if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy()) {
+      if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() ||
+          Ty->isDoubleTy()) {
         Record.push_back(CFP->getValueAPF().bitcastToAPInt().getZExtValue());
       } else if (Ty->isX86_FP80Ty()) {
         // api needed to prevent premature destruction

diff  --git a/llvm/lib/CodeGen/MIRParser/MILexer.cpp b/llvm/lib/CodeGen/MIRParser/MILexer.cpp
index e4852d306941..0fbedc424949 100644
--- a/llvm/lib/CodeGen/MIRParser/MILexer.cpp
+++ b/llvm/lib/CodeGen/MIRParser/MILexer.cpp
@@ -534,7 +534,7 @@ static Cursor maybeLexMCSymbol(Cursor C, MIToken &Token,
 }
 
 static bool isValidHexFloatingPointPrefix(char C) {
-  return C == 'H' || C == 'K' || C == 'L' || C == 'M';
+  return C == 'H' || C == 'K' || C == 'L' || C == 'M' || C == 'R';
 }
 
 static Cursor lexFloatingPointLiteral(Cursor Range, Cursor C, MIToken &Token) {

diff  --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 6f451a12c4c0..72da461ddcb8 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -588,6 +588,7 @@ void TypePrinting::print(Type *Ty, raw_ostream &OS) {
   switch (Ty->getTypeID()) {
   case Type::VoidTyID:      OS << "void"; return;
   case Type::HalfTyID:      OS << "half"; return;
+  case Type::BFloatTyID:    OS << "bfloat"; return;
   case Type::FloatTyID:     OS << "float"; return;
   case Type::DoubleTyID:    OS << "double"; return;
   case Type::X86_FP80TyID:  OS << "x86_fp80"; return;
@@ -1379,7 +1380,7 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
       return;
     }
 
-    // Either half, or some form of long double.
+    // Either half, bfloat or some form of long double.
     // These appear as a magic letter identifying the type, then a
     // fixed number of hex digits.
     Out << "0x";
@@ -1407,6 +1408,10 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
       Out << 'H';
       Out << format_hex_no_prefix(API.getZExtValue(), 4,
                                   /*Upper=*/true);
+    } else if (&APF.getSemantics() == &APFloat::BFloat()) {
+      Out << 'R';
+      Out << format_hex_no_prefix(API.getZExtValue(), 4,
+                                  /*Upper=*/true);
     } else
       llvm_unreachable("Unsupported floating point type");
     return;

diff  --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 5a3c6a44ceb2..88971d89bf4c 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -332,6 +332,9 @@ Constant *Constant::getNullValue(Type *Ty) {
   case Type::HalfTyID:
     return ConstantFP::get(Ty->getContext(),
                            APFloat::getZero(APFloat::IEEEhalf()));
+  case Type::BFloatTyID:
+    return ConstantFP::get(Ty->getContext(),
+                           APFloat::getZero(APFloat::BFloat()));
   case Type::FloatTyID:
     return ConstantFP::get(Ty->getContext(),
                            APFloat::getZero(APFloat::IEEEsingle()));
@@ -386,8 +389,8 @@ Constant *Constant::getAllOnesValue(Type *Ty) {
                             APInt::getAllOnesValue(ITy->getBitWidth()));
 
   if (Ty->isFloatingPointTy()) {
-    APFloat FL = APFloat::getAllOnesValue(Ty->getPrimitiveSizeInBits(),
-                                          !Ty->isPPC_FP128Ty());
+    APFloat FL = APFloat::getAllOnesValue(Ty->getFltSemantics(),
+                                          Ty->getPrimitiveSizeInBits());
     return ConstantFP::get(Ty->getContext(), FL);
   }
 
@@ -763,6 +766,8 @@ void ConstantInt::destroyConstantImpl() {
 static const fltSemantics *TypeToFloatSemantics(Type *Ty) {
   if (Ty->isHalfTy())
     return &APFloat::IEEEhalf();
+  if (Ty->isBFloatTy())
+    return &APFloat::BFloat();
   if (Ty->isFloatTy())
     return &APFloat::IEEEsingle();
   if (Ty->isDoubleTy())
@@ -880,6 +885,8 @@ ConstantFP* ConstantFP::get(LLVMContext &Context, const APFloat& V) {
     Type *Ty;
     if (&V.getSemantics() == &APFloat::IEEEhalf())
       Ty = Type::getHalfTy(Context);
+    else if (&V.getSemantics() == &APFloat::BFloat())
+      Ty = Type::getBFloatTy(Context);
     else if (&V.getSemantics() == &APFloat::IEEEsingle())
       Ty = Type::getFloatTy(Context);
     else if (&V.getSemantics() == &APFloat::IEEEdouble())
@@ -1029,7 +1036,7 @@ static Constant *getFPSequenceIfElementsMatch(ArrayRef<Constant *> V) {
       Elts.push_back(CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
     else
       return nullptr;
-  return SequentialTy::getFP(V[0]->getContext(), Elts);
+  return SequentialTy::getFP(V[0]->getType(), Elts);
 }
 
 template <typename SequenceTy>
@@ -1048,7 +1055,7 @@ static Constant *getSequenceIfElementsMatch(Constant *C,
     else if (CI->getType()->isIntegerTy(64))
       return getIntSequenceIfElementsMatch<SequenceTy, uint64_t>(V);
   } else if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
-    if (CFP->getType()->isHalfTy())
+    if (CFP->getType()->isHalfTy() || CFP->getType()->isBFloatTy())
       return getFPSequenceIfElementsMatch<SequenceTy, uint16_t>(V);
     else if (CFP->getType()->isFloatTy())
       return getFPSequenceIfElementsMatch<SequenceTy, uint32_t>(V);
@@ -1421,6 +1428,12 @@ bool ConstantFP::isValueValidForType(Type *Ty, const APFloat& Val) {
     Val2.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &losesInfo);
     return !losesInfo;
   }
+  case Type::BFloatTyID: {
+    if (&Val2.getSemantics() == &APFloat::BFloat())
+      return true;
+    Val2.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &losesInfo);
+    return !losesInfo;
+  }
   case Type::FloatTyID: {
     if (&Val2.getSemantics() == &APFloat::IEEEsingle())
       return true;
@@ -1429,6 +1442,7 @@ bool ConstantFP::isValueValidForType(Type *Ty, const APFloat& Val) {
   }
   case Type::DoubleTyID: {
     if (&Val2.getSemantics() == &APFloat::IEEEhalf() ||
+        &Val2.getSemantics() == &APFloat::BFloat() ||
         &Val2.getSemantics() == &APFloat::IEEEsingle() ||
         &Val2.getSemantics() == &APFloat::IEEEdouble())
       return true;
@@ -1437,16 +1451,19 @@ bool ConstantFP::isValueValidForType(Type *Ty, const APFloat& Val) {
   }
   case Type::X86_FP80TyID:
     return &Val2.getSemantics() == &APFloat::IEEEhalf() ||
+           &Val2.getSemantics() == &APFloat::BFloat() ||
            &Val2.getSemantics() == &APFloat::IEEEsingle() ||
            &Val2.getSemantics() == &APFloat::IEEEdouble() ||
            &Val2.getSemantics() == &APFloat::x87DoubleExtended();
   case Type::FP128TyID:
     return &Val2.getSemantics() == &APFloat::IEEEhalf() ||
+           &Val2.getSemantics() == &APFloat::BFloat() ||
            &Val2.getSemantics() == &APFloat::IEEEsingle() ||
            &Val2.getSemantics() == &APFloat::IEEEdouble() ||
            &Val2.getSemantics() == &APFloat::IEEEquad();
   case Type::PPC_FP128TyID:
     return &Val2.getSemantics() == &APFloat::IEEEhalf() ||
+           &Val2.getSemantics() == &APFloat::BFloat() ||
            &Val2.getSemantics() == &APFloat::IEEEsingle() ||
            &Val2.getSemantics() == &APFloat::IEEEdouble() ||
            &Val2.getSemantics() == &APFloat::PPCDoubleDouble();
@@ -2562,7 +2579,8 @@ StringRef ConstantDataSequential::getRawDataValues() const {
 }
 
 bool ConstantDataSequential::isElementTypeCompatible(Type *Ty) {
-  if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy()) return true;
+  if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() || Ty->isDoubleTy())
+    return true;
   if (auto *IT = dyn_cast<IntegerType>(Ty)) {
     switch (IT->getBitWidth()) {
     case 8:
@@ -2680,26 +2698,29 @@ void ConstantDataSequential::destroyConstantImpl() {
   Next = nullptr;
 }
 
-/// getFP() constructors - Return a constant with array type with an element
-/// count and element type of float with precision matching the number of
-/// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits,
-/// double for 64bits) Note that this can return a ConstantAggregateZero
-/// object.
-Constant *ConstantDataArray::getFP(LLVMContext &Context,
-                                   ArrayRef<uint16_t> Elts) {
-  Type *Ty = ArrayType::get(Type::getHalfTy(Context), Elts.size());
+/// getFP() constructors - Return a constant of array type with a float
+/// element type taken from argument `ElementType', and count taken from
+/// argument `Elts'.  The number of bits of the element type must match the
+/// bit width of the integers in the passed-in ArrayRef (i.e. half or bfloat
+/// for 16 bits, float for 32 bits, double for 64 bits). Note that this can
+/// return a ConstantAggregateZero object.
+Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint16_t> Elts) {
+  assert((ElementType->isHalfTy() || ElementType->isBFloatTy()) &&
+         "Element type is not a 16-bit float type");
+  Type *Ty = ArrayType::get(ElementType, Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
   return getImpl(StringRef(Data, Elts.size() * 2), Ty);
 }
-Constant *ConstantDataArray::getFP(LLVMContext &Context,
-                                   ArrayRef<uint32_t> Elts) {
-  Type *Ty = ArrayType::get(Type::getFloatTy(Context), Elts.size());
+Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint32_t> Elts) {
+  assert(ElementType->isFloatTy() && "Element type is not a 32-bit float type");
+  Type *Ty = ArrayType::get(ElementType, Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
   return getImpl(StringRef(Data, Elts.size() * 4), Ty);
 }
-Constant *ConstantDataArray::getFP(LLVMContext &Context,
-                                   ArrayRef<uint64_t> Elts) {
-  Type *Ty = ArrayType::get(Type::getDoubleTy(Context), Elts.size());
+Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint64_t> Elts) {
+  assert(ElementType->isDoubleTy() &&
+         "Element type is not a 64-bit float type");
+  Type *Ty = ArrayType::get(ElementType, Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
   return getImpl(StringRef(Data, Elts.size() * 8), Ty);
 }
@@ -2751,26 +2772,32 @@ Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<double> Elts) {
   return getImpl(StringRef(Data, Elts.size() * 8), Ty);
 }
 
-/// getFP() constructors - Return a constant with vector type with an element
-/// count and element type of float with the precision matching the number of
-/// bits in the ArrayRef passed in.  (i.e. half for 16bits, float for 32bits,
-/// double for 64bits) Note that this can return a ConstantAggregateZero
-/// object.
-Constant *ConstantDataVector::getFP(LLVMContext &Context,
+/// getFP() constructors - Return a constant of vector type with a float
+/// element type taken from argument `ElementType', and count taken from
+/// argument `Elts'.  The number of bits of the element type must match the
+/// bit width of the integers in the passed-in ArrayRef (i.e. half or bfloat
+/// for 16 bits, float for 32 bits, double for 64 bits). Note that this can
+/// return a ConstantAggregateZero object.
+Constant *ConstantDataVector::getFP(Type *ElementType,
                                     ArrayRef<uint16_t> Elts) {
-  Type *Ty = VectorType::get(Type::getHalfTy(Context), Elts.size());
+  assert((ElementType->isHalfTy() || ElementType->isBFloatTy()) &&
+         "Element type is not a 16-bit float type");
+  Type *Ty = VectorType::get(ElementType, Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
   return getImpl(StringRef(Data, Elts.size() * 2), Ty);
 }
-Constant *ConstantDataVector::getFP(LLVMContext &Context,
+Constant *ConstantDataVector::getFP(Type *ElementType,
                                     ArrayRef<uint32_t> Elts) {
-  Type *Ty = VectorType::get(Type::getFloatTy(Context), Elts.size());
+  assert(ElementType->isFloatTy() && "Element type is not a 32-bit float type");
+  Type *Ty = VectorType::get(ElementType, Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
   return getImpl(StringRef(Data, Elts.size() * 4), Ty);
 }
-Constant *ConstantDataVector::getFP(LLVMContext &Context,
+Constant *ConstantDataVector::getFP(Type *ElementType,
                                     ArrayRef<uint64_t> Elts) {
-  Type *Ty = VectorType::get(Type::getDoubleTy(Context), Elts.size());
+  assert(ElementType->isDoubleTy() &&
+         "Element type is not a 64-bit float type");
+  Type *Ty = VectorType::get(ElementType, Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
   return getImpl(StringRef(Data, Elts.size() * 8), Ty);
 }
@@ -2800,17 +2827,22 @@ Constant *ConstantDataVector::getSplat(unsigned NumElts, Constant *V) {
     if (CFP->getType()->isHalfTy()) {
       SmallVector<uint16_t, 16> Elts(
           NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
-      return getFP(V->getContext(), Elts);
+      return getFP(V->getType(), Elts);
+    }
+    if (CFP->getType()->isBFloatTy()) {
+      SmallVector<uint16_t, 16> Elts(
+          NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
+      return getFP(V->getType(), Elts);
     }
     if (CFP->getType()->isFloatTy()) {
       SmallVector<uint32_t, 16> Elts(
           NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
-      return getFP(V->getContext(), Elts);
+      return getFP(V->getType(), Elts);
     }
     if (CFP->getType()->isDoubleTy()) {
       SmallVector<uint64_t, 16> Elts(
           NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
-      return getFP(V->getContext(), Elts);
+      return getFP(V->getType(), Elts);
     }
   }
   return ConstantVector::getSplat({NumElts, false}, V);
@@ -2875,6 +2907,10 @@ APFloat ConstantDataSequential::getElementAsAPFloat(unsigned Elt) const {
     auto EltVal = *reinterpret_cast<const uint16_t *>(EltPtr);
     return APFloat(APFloat::IEEEhalf(), APInt(16, EltVal));
   }
+  case Type::BFloatTyID: {
+    auto EltVal = *reinterpret_cast<const uint16_t *>(EltPtr);
+    return APFloat(APFloat::BFloat(), APInt(16, EltVal));
+  }
   case Type::FloatTyID: {
     auto EltVal = *reinterpret_cast<const uint32_t *>(EltPtr);
     return APFloat(APFloat::IEEEsingle(), APInt(32, EltVal));
@@ -2899,8 +2935,8 @@ double ConstantDataSequential::getElementAsDouble(unsigned Elt) const {
 }
 
 Constant *ConstantDataSequential::getElementAsConstant(unsigned Elt) const {
-  if (getElementType()->isHalfTy() || getElementType()->isFloatTy() ||
-      getElementType()->isDoubleTy())
+  if (getElementType()->isHalfTy() || getElementType()->isBFloatTy() ||
+      getElementType()->isFloatTy() || getElementType()->isDoubleTy())
     return ConstantFP::get(getContext(), getElementAsAPFloat(Elt));
 
   return ConstantInt::get(getElementType(), getElementAsInteger(Elt));

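With getFP() now taking an explicit element type, a raw ArrayRef<uint16_t> no longer implies half: the same 16 bits decode to different values under the two 16-bit formats. A minimal stand-alone sketch in plain C++ (no LLVM; decode16, decodeHalf and decodeBFloat are hypothetical names) that decodes a pattern from its field widths, 5/10 for half and 8/7 for bfloat:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical helper (not an LLVM API): decode a 16-bit binary floating-point
// pattern laid out as 1 sign bit, ExpBits exponent bits and MantBits stored
// mantissa bits, with an implicit integer bit for normal numbers.
double decode16(uint16_t Bits, int ExpBits, int MantBits) {
  assert(1 + ExpBits + MantBits == 16 && "fields must fill 16 bits");
  const uint32_t ExpMask = (1u << ExpBits) - 1;
  const uint32_t MantMask = (1u << MantBits) - 1;
  const int Bias = (1 << (ExpBits - 1)) - 1;

  const bool Negative = (Bits >> 15) & 1;
  const uint32_t Exp = (Bits >> MantBits) & ExpMask;
  const uint32_t Mant = Bits & MantMask;

  double Value;
  if (Exp == ExpMask) {
    // All-ones exponent: infinity if the mantissa is zero, otherwise NaN.
    Value = Mant == 0 ? INFINITY : NAN;
  } else if (Exp == 0) {
    // Zero or denormal: no integer bit, exponent fixed at 1 - Bias.
    Value = std::ldexp(static_cast<double>(Mant), 1 - Bias - MantBits);
  } else {
    // Normal: restore the implicit integer bit.
    Value = std::ldexp(static_cast<double>(Mant | (1u << MantBits)),
                       static_cast<int>(Exp) - Bias - MantBits);
  }
  return Negative ? -Value : Value;
}

double decodeHalf(uint16_t Bits) { return decode16(Bits, 5, 10); }  // IEEE half
double decodeBFloat(uint16_t Bits) { return decode16(Bits, 8, 7); } // bfloat
```

For example, 0x3C00 is 1.0 as half but 0.0078125 (2^-7) as bfloat, which is why the element type has to travel with the raw data.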
diff --git a/llvm/lib/IR/Core.cpp b/llvm/lib/IR/Core.cpp
index 696c25fc4f83..3bb193797f24 100644
--- a/llvm/lib/IR/Core.cpp
+++ b/llvm/lib/IR/Core.cpp
@@ -477,6 +477,8 @@ LLVMTypeKind LLVMGetTypeKind(LLVMTypeRef Ty) {
     return LLVMVoidTypeKind;
   case Type::HalfTyID:
     return LLVMHalfTypeKind;
+  case Type::BFloatTyID:
+    return LLVMBFloatTypeKind;
   case Type::FloatTyID:
     return LLVMFloatTypeKind;
   case Type::DoubleTyID:
@@ -595,6 +597,9 @@ unsigned LLVMGetIntTypeWidth(LLVMTypeRef IntegerTy) {
 LLVMTypeRef LLVMHalfTypeInContext(LLVMContextRef C) {
   return (LLVMTypeRef) Type::getHalfTy(*unwrap(C));
 }
+LLVMTypeRef LLVMBFloatTypeInContext(LLVMContextRef C) {
+  return (LLVMTypeRef) Type::getBFloatTy(*unwrap(C));
+}
 LLVMTypeRef LLVMFloatTypeInContext(LLVMContextRef C) {
   return (LLVMTypeRef) Type::getFloatTy(*unwrap(C));
 }
@@ -617,6 +622,9 @@ LLVMTypeRef LLVMX86MMXTypeInContext(LLVMContextRef C) {
 LLVMTypeRef LLVMHalfType(void) {
   return LLVMHalfTypeInContext(LLVMGetGlobalContext());
 }
+LLVMTypeRef LLVMBFloatType(void) {
+  return LLVMBFloatTypeInContext(LLVMGetGlobalContext());
+}
 LLVMTypeRef LLVMFloatType(void) {
   return LLVMFloatTypeInContext(LLVMGetGlobalContext());
 }

diff --git a/llvm/lib/IR/DataLayout.cpp b/llvm/lib/IR/DataLayout.cpp
index 0a25f1c14da0..87563d988ed7 100644
--- a/llvm/lib/IR/DataLayout.cpp
+++ b/llvm/lib/IR/DataLayout.cpp
@@ -162,7 +162,7 @@ static const LayoutAlignElem DefaultAlignments[] = {
     {INTEGER_ALIGN, 16, Align(2), Align(2)},   // i16
     {INTEGER_ALIGN, 32, Align(4), Align(4)},   // i32
     {INTEGER_ALIGN, 64, Align(4), Align(8)},   // i64
-    {FLOAT_ALIGN, 16, Align(2), Align(2)},     // half
+    {FLOAT_ALIGN, 16, Align(2), Align(2)},     // half, bfloat
     {FLOAT_ALIGN, 32, Align(4), Align(4)},     // float
     {FLOAT_ALIGN, 64, Align(8), Align(8)},     // double
     {FLOAT_ALIGN, 128, Align(16), Align(16)},  // ppcf128, quad, ...
@@ -732,6 +732,7 @@ Align DataLayout::getAlignment(Type *Ty, bool abi_or_pref) const {
     AlignType = INTEGER_ALIGN;
     break;
   case Type::HalfTyID:
+  case Type::BFloatTyID:
   case Type::FloatTyID:
   case Type::DoubleTyID:
   // PPC_FP128TyID and FP128TyID have different data contents, but the

diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index dab1c336c419..7bf3ab59b877 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -655,6 +655,7 @@ static std::string getMangledTypeStr(Type* Ty) {
     case Type::VoidTyID:      Result += "isVoid";   break;
     case Type::MetadataTyID:  Result += "Metadata"; break;
     case Type::HalfTyID:      Result += "f16";      break;
+    case Type::BFloatTyID:    Result += "bf16";     break;
     case Type::FloatTyID:     Result += "f32";      break;
     case Type::DoubleTyID:    Result += "f64";      break;
     case Type::X86_FP80TyID:  Result += "f80";      break;

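Because half and bfloat are both 16 bits wide, overloaded intrinsic names need a separate suffix for bfloat, hence the new "bf16" case in getMangledTypeStr(). The sketch below is a simplified, hypothetical model of that suffix scheme (FPKind, mangleScalar and mangleVector are illustrative names, not LLVM APIs). It assumes the "v<N>" and "nx" prefixes that getMangledTypeStr() uses for fixed and scalable vectors; those branches are not shown in this diff.

```cpp
#include <cassert>
#include <string>

// Hypothetical, simplified model of the type-suffix scheme; it covers only
// scalar FP types and (scalable) vectors of them.
enum class FPKind { Half, BFloat, Float, Double };

std::string mangleScalar(FPKind K) {
  switch (K) {
  case FPKind::Half:   return "f16";
  case FPKind::BFloat: return "bf16"; // same width as half, distinct suffix
  case FPKind::Float:  return "f32";
  case FPKind::Double: return "f64";
  }
  return ""; // not reached for valid enumerators
}

// Assumed vector form: optional "nx" for scalable vectors, then "v", the
// minimum element count and the element suffix.
std::string mangleVector(FPKind K, unsigned MinElts, bool Scalable) {
  assert(MinElts > 0 && "vector must have at least one element");
  return (Scalable ? "nx" : "") + std::string("v") + std::to_string(MinElts) +
         mangleScalar(K);
}
```

Under these assumptions a <vscale x 4 x bfloat> overload would be mangled as nxv4bf16, distinct from the nxv4f16 used for <vscale x 4 x half>.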
diff --git a/llvm/lib/IR/LLVMContextImpl.cpp b/llvm/lib/IR/LLVMContextImpl.cpp
index 68b8f8aef680..f197b3e67d30 100644
--- a/llvm/lib/IR/LLVMContextImpl.cpp
+++ b/llvm/lib/IR/LLVMContextImpl.cpp
@@ -26,6 +26,7 @@ LLVMContextImpl::LLVMContextImpl(LLVMContext &C)
     VoidTy(C, Type::VoidTyID),
     LabelTy(C, Type::LabelTyID),
     HalfTy(C, Type::HalfTyID),
+    BFloatTy(C, Type::BFloatTyID),
     FloatTy(C, Type::FloatTyID),
     DoubleTy(C, Type::DoubleTyID),
     MetadataTy(C, Type::MetadataTyID),

diff --git a/llvm/lib/IR/LLVMContextImpl.h b/llvm/lib/IR/LLVMContextImpl.h
index a019f1ee07b9..9912808c53c2 100644
--- a/llvm/lib/IR/LLVMContextImpl.h
+++ b/llvm/lib/IR/LLVMContextImpl.h
@@ -1342,7 +1342,8 @@ class LLVMContextImpl {
   std::unique_ptr<ConstantTokenNone> TheNoneToken;
 
   // Basic type instances.
-  Type VoidTy, LabelTy, HalfTy, FloatTy, DoubleTy, MetadataTy, TokenTy;
+  Type VoidTy, LabelTy, HalfTy, BFloatTy, FloatTy, DoubleTy, MetadataTy,
+      TokenTy;
   Type X86_FP80Ty, FP128Ty, PPC_FP128Ty, X86_MMXTy;
   IntegerType Int1Ty, Int8Ty, Int16Ty, Int32Ty, Int64Ty, Int128Ty;
 

diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp
index 8e5b03211132..bb077c1fd019 100644
--- a/llvm/lib/IR/Type.cpp
+++ b/llvm/lib/IR/Type.cpp
@@ -40,6 +40,7 @@ Type *Type::getPrimitiveType(LLVMContext &C, TypeID IDNumber) {
   switch (IDNumber) {
   case VoidTyID      : return getVoidTy(C);
   case HalfTyID      : return getHalfTy(C);
+  case BFloatTyID    : return getBFloatTy(C);
   case FloatTyID     : return getFloatTy(C);
   case DoubleTyID    : return getDoubleTy(C);
   case X86_FP80TyID  : return getX86_FP80Ty(C);
@@ -112,6 +113,7 @@ bool Type::isEmptyTy() const {
 TypeSize Type::getPrimitiveSizeInBits() const {
   switch (getTypeID()) {
   case Type::HalfTyID: return TypeSize::Fixed(16);
+  case Type::BFloatTyID: return TypeSize::Fixed(16);
   case Type::FloatTyID: return TypeSize::Fixed(32);
   case Type::DoubleTyID: return TypeSize::Fixed(64);
   case Type::X86_FP80TyID: return TypeSize::Fixed(80);
@@ -142,6 +144,7 @@ int Type::getFPMantissaWidth() const {
     return VTy->getElementType()->getFPMantissaWidth();
   assert(isFloatingPointTy() && "Not a floating point type!");
   if (getTypeID() == HalfTyID) return 11;
+  if (getTypeID() == BFloatTyID) return 8;
   if (getTypeID() == FloatTyID) return 24;
   if (getTypeID() == DoubleTyID) return 53;
   if (getTypeID() == X86_FP80TyID) return 64;
@@ -167,6 +170,7 @@ bool Type::isSizedDerivedType(SmallPtrSetImpl<Type*> *Visited) const {
 Type *Type::getVoidTy(LLVMContext &C) { return &C.pImpl->VoidTy; }
 Type *Type::getLabelTy(LLVMContext &C) { return &C.pImpl->LabelTy; }
 Type *Type::getHalfTy(LLVMContext &C) { return &C.pImpl->HalfTy; }
+Type *Type::getBFloatTy(LLVMContext &C) { return &C.pImpl->BFloatTy; }
 Type *Type::getFloatTy(LLVMContext &C) { return &C.pImpl->FloatTy; }
 Type *Type::getDoubleTy(LLVMContext &C) { return &C.pImpl->DoubleTy; }
 Type *Type::getMetadataTy(LLVMContext &C) { return &C.pImpl->MetadataTy; }
@@ -191,6 +195,10 @@ PointerType *Type::getHalfPtrTy(LLVMContext &C, unsigned AS) {
   return getHalfTy(C)->getPointerTo(AS);
 }
 
+PointerType *Type::getBFloatPtrTy(LLVMContext &C, unsigned AS) {
+  return getBFloatTy(C)->getPointerTo(AS);
+}
+
 PointerType *Type::getFloatPtrTy(LLVMContext &C, unsigned AS) {
   return getFloatTy(C)->getPointerTo(AS);
 }

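The new getPrimitiveSizeInBits() and getFPMantissaWidth() cases (16 and 8) are enough to recover the rest of bfloat's layout, assuming an IEEE-style format with one sign bit and an implicit integer bit (this does not hold for x86_fp80, whose integer bit is explicit). A hypothetical stand-alone helper, FPLayout/deriveLayout, that derives the exponent width, bias and all-ones exponent:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: given a format's total width and its mantissa width
// including the implicit integer bit (what getFPMantissaWidth() returns),
// derive the remaining IEEE-style layout parameters.
struct FPLayout {
  unsigned ExponentBits;    // width of the biased exponent field
  int Bias;                 // exponent bias, equal to the maximum exponent
  int MinExponent;          // smallest normal exponent, 1 - Bias
  uint32_t ExponentAllOnes; // field value used for infinity and NaN
};

FPLayout deriveLayout(unsigned SizeInBits, unsigned MantissaWidth) {
  // SizeInBits = 1 (sign) + ExponentBits + (MantissaWidth - 1) (stored
  // mantissa), so the sign bit and the implicit bit cancel out.
  assert(MantissaWidth > 0 && MantissaWidth < SizeInBits);
  FPLayout L;
  L.ExponentBits = SizeInBits - MantissaWidth;
  L.Bias = (1 << (L.ExponentBits - 1)) - 1;
  L.MinExponent = 1 - L.Bias;
  L.ExponentAllOnes = (1u << L.ExponentBits) - 1;
  return L;
}
```

For (16, 8) this yields an 8-bit exponent with bias 127 and minimum exponent -126, matching semBFloat = {127, -126, 8, 16}, and an all-ones exponent of 0xff; for half's (16, 11) the all-ones exponent is 0x1f.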
diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index 63114faae117..78f44c5e47bb 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -69,6 +69,7 @@ namespace llvm {
   };
 
   static const fltSemantics semIEEEhalf = {15, -14, 11, 16};
+  static const fltSemantics semBFloat = {127, -126, 8, 16};
   static const fltSemantics semIEEEsingle = {127, -126, 24, 32};
   static const fltSemantics semIEEEdouble = {1023, -1022, 53, 64};
   static const fltSemantics semIEEEquad = {16383, -16382, 113, 128};
@@ -117,6 +118,8 @@ namespace llvm {
     switch (S) {
     case S_IEEEhalf:
       return IEEEhalf();
+    case S_BFloat:
+      return BFloat();
     case S_IEEEsingle:
       return IEEEsingle();
     case S_IEEEdouble:
@@ -135,6 +138,8 @@ namespace llvm {
   APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) {
     if (&Sem == &llvm::APFloat::IEEEhalf())
       return S_IEEEhalf;
+    else if (&Sem == &llvm::APFloat::BFloat())
+      return S_BFloat;
     else if (&Sem == &llvm::APFloat::IEEEsingle())
       return S_IEEEsingle;
     else if (&Sem == &llvm::APFloat::IEEEdouble())
@@ -152,6 +157,9 @@ namespace llvm {
   const fltSemantics &APFloatBase::IEEEhalf() {
     return semIEEEhalf;
   }
+  const fltSemantics &APFloatBase::BFloat() {
+    return semBFloat;
+  }
   const fltSemantics &APFloatBase::IEEEsingle() {
     return semIEEEsingle;
   }
@@ -3255,6 +3263,33 @@ APInt IEEEFloat::convertFloatAPFloatToAPInt() const {
                     (mysignificand & 0x7fffff)));
 }
 
+APInt IEEEFloat::convertBFloatAPFloatToAPInt() const {
+  assert(semantics == (const llvm::fltSemantics *)&semBFloat);
+  assert(partCount() == 1);
+
+  uint32_t myexponent, mysignificand;
+
+  if (isFiniteNonZero()) {
+    myexponent = exponent + 127; // bias
+    mysignificand = (uint32_t)*significandParts();
+    if (myexponent == 1 && !(mysignificand & 0x80))
+      myexponent = 0; // denormal
+  } else if (category == fcZero) {
+    myexponent = 0;
+    mysignificand = 0;
+  } else if (category == fcInfinity) {
+    myexponent = 0xff;
+    mysignificand = 0;
+  } else {
+    assert(category == fcNaN && "Unknown category!");
+    myexponent = 0xff;
+    mysignificand = (uint32_t)*significandParts();
+  }
+
+  return APInt(16, (((sign & 1) << 15) | ((myexponent & 0xff) << 7) |
+                    (mysignificand & 0x7f)));
+}
+
 APInt IEEEFloat::convertHalfAPFloatToAPInt() const {
   assert(semantics == (const llvm::fltSemantics*)&semIEEEhalf);
   assert(partCount()==1);
@@ -3290,6 +3325,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
   if (semantics == (const llvm::fltSemantics*)&semIEEEhalf)
     return convertHalfAPFloatToAPInt();
 
+  if (semantics == (const llvm::fltSemantics *)&semBFloat)
+    return convertBFloatAPFloatToAPInt();
+
   if (semantics == (const llvm::fltSemantics*)&semIEEEsingle)
     return convertFloatAPFloatToAPInt();
 
@@ -3486,6 +3524,37 @@ void IEEEFloat::initFromFloatAPInt(const APInt &api) {
   }
 }
 
+void IEEEFloat::initFromBFloatAPInt(const APInt &api) {
+  assert(api.getBitWidth() == 16);
+  uint32_t i = (uint32_t)*api.getRawData();
+  uint32_t myexponent = (i >> 7) & 0xff;
+  uint32_t mysignificand = i & 0x7f;
+
+  initialize(&semBFloat);
+  assert(partCount() == 1);
+
+  sign = i >> 15;
+  if (myexponent == 0 && mysignificand == 0) {
+    // exponent, significand meaningless
+    category = fcZero;
+  } else if (myexponent == 0xff && mysignificand == 0) {
+    // exponent, significand meaningless
+    category = fcInfinity;
+  } else if (myexponent == 0xff && mysignificand != 0) {
+    // sign, exponent, significand meaningless
+    category = fcNaN;
+    *significandParts() = mysignificand;
+  } else {
+    category = fcNormal;
+    exponent = myexponent - 127; // bias
+    *significandParts() = mysignificand;
+    if (myexponent == 0) // denormal
+      exponent = -126;
+    else
+      *significandParts() |= 0x80; // integer bit
+  }
+}
+
 void IEEEFloat::initFromHalfAPInt(const APInt &api) {
   assert(api.getBitWidth()==16);
   uint32_t i = (uint32_t)*api.getRawData();
@@ -3524,6 +3593,8 @@ void IEEEFloat::initFromHalfAPInt(const APInt &api) {
 void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
   if (Sem == &semIEEEhalf)
     return initFromHalfAPInt(api);
+  if (Sem == &semBFloat)
+    return initFromBFloatAPInt(api);
   if (Sem == &semIEEEsingle)
     return initFromFloatAPInt(api);
   if (Sem == &semIEEEdouble)
@@ -4763,26 +4834,9 @@ APFloat::opStatus APFloat::convert(const fltSemantics &ToSemantics,
   llvm_unreachable("Unexpected semantics");
 }
 
-APFloat APFloat::getAllOnesValue(unsigned BitWidth, bool isIEEE) {
-  if (isIEEE) {
-    switch (BitWidth) {
-    case 16:
-      return APFloat(semIEEEhalf, APInt::getAllOnesValue(BitWidth));
-    case 32:
-      return APFloat(semIEEEsingle, APInt::getAllOnesValue(BitWidth));
-    case 64:
-      return APFloat(semIEEEdouble, APInt::getAllOnesValue(BitWidth));
-    case 80:
-      return APFloat(semX87DoubleExtended, APInt::getAllOnesValue(BitWidth));
-    case 128:
-      return APFloat(semIEEEquad, APInt::getAllOnesValue(BitWidth));
-    default:
-      llvm_unreachable("Unknown floating bit width");
-    }
-  } else {
-    assert(BitWidth == 128);
-    return APFloat(semPPCDoubleDouble, APInt::getAllOnesValue(BitWidth));
-  }
+APFloat APFloat::getAllOnesValue(const fltSemantics &Semantics,
+                                 unsigned BitWidth) {
+  return APFloat(Semantics, APInt::getAllOnesValue(BitWidth));
 }
 
 void APFloat::print(raw_ostream &OS) const {

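The bit packing done by convertBFloatAPFloatToAPInt() can be mirrored outside LLVM to check the special cases. In this sketch the Category enum and packBFloat() are hypothetical stand-ins for APFloat's internal category, exponent and significand; the significand carries the integer bit at bit 7, as in the code above. Infinity and NaN must use bfloat's all-ones exponent 0xff; half's 0x1f would instead produce 0x0F80, which decodes as the small finite number 2^-96.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-alone mirror of the packing step in
// convertBFloatAPFloatToAPInt(); not LLVM code.
enum class Category { Zero, Normal, Infinity, NaN };

// Exponent is unbiased; Significand holds 8 bits with the integer bit at bit 7.
uint16_t packBFloat(Category Cat, bool Sign, int Exponent, uint32_t Significand) {
  uint32_t BiasedExp, Mantissa;
  switch (Cat) {
  case Category::Normal:
    assert(Exponent >= -126 && Exponent <= 127 && "exponent out of range");
    BiasedExp = static_cast<uint32_t>(Exponent + 127);
    Mantissa = Significand;
    // Minimum exponent without the integer bit is a denormal, which is
    // encoded with a zero exponent field.
    if (BiasedExp == 1 && !(Mantissa & 0x80))
      BiasedExp = 0;
    break;
  case Category::Zero:
    BiasedExp = 0;
    Mantissa = 0;
    break;
  case Category::Infinity:
    BiasedExp = 0xff; // all ones in bfloat's 8-bit exponent field
    Mantissa = 0;
    break;
  case Category::NaN:
  default:
    BiasedExp = 0xff;
    Mantissa = Significand; // a non-zero stored mantissa keeps it a NaN
    break;
  }
  // The integer bit (bit 7) is dropped; only 7 mantissa bits are stored.
  return static_cast<uint16_t>(((Sign ? 1u : 0u) << 15) |
                               ((BiasedExp & 0xff) << 7) | (Mantissa & 0x7f));
}
```

For instance, 6.0 is exponent 2 with significand 0xC0 (binary 1.1000000) and packs to 0x40C0, while positive infinity packs to 0x7F80.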
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetObjectFile.cpp b/llvm/lib/Target/Hexagon/HexagonTargetObjectFile.cpp
index 97aee3a10207..0af26255109c 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetObjectFile.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetObjectFile.cpp
@@ -323,6 +323,7 @@ unsigned HexagonTargetObjectFile::getSmallestAddressableSize(const Type *Ty,
   }
   case Type::FunctionTyID:
   case Type::VoidTyID:
+  case Type::BFloatTyID:
   case Type::X86_FP80TyID:
   case Type::FP128TyID:
   case Type::PPC_FP128TyID:

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 29ced3991505..249f7b25f932 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -11526,8 +11526,9 @@ static SDValue lowerShuffleAsBitMask(const SDLoc &DL, MVT VT, SDValue V1,
   MVT LogicVT = VT;
   if (EltVT == MVT::f32 || EltVT == MVT::f64) {
     Zero = DAG.getConstantFP(0.0, DL, EltVT);
-    AllOnes = DAG.getConstantFP(
-        APFloat::getAllOnesValue(EltVT.getSizeInBits(), true), DL, EltVT);
+    APFloat AllOnesValue = APFloat::getAllOnesValue(
+        SelectionDAG::EVTToAPFloatSemantics(EltVT), EltVT.getSizeInBits());
+    AllOnes = DAG.getConstantFP(AllOnesValue, DL, EltVT);
     LogicVT =
         MVT::getVectorVT(EltVT == MVT::f64 ? MVT::i64 : MVT::i32, Mask.size());
   } else {

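With two 16-bit formats, a bit width alone no longer identifies the semantics, so getAllOnesValue() now takes the fltSemantics explicitly. The X86 caller uses the result purely as a bit mask; read as a number, an all-ones IEEE pattern is a negative NaN. A small stand-alone check, assuming float and double are IEEE single and double (allOnesFloat, allOnesDouble and isNegativeNaN are hypothetical names):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical helpers: reinterpret an all-ones bit pattern as float/double,
// the IEEE single/double counterpart of an all-ones APInt under those semantics.
float allOnesFloat() {
  const uint32_t Bits = ~UINT32_C(0); // 0xFFFFFFFF
  float Value;
  static_assert(sizeof(Value) == sizeof(Bits), "expects a 32-bit float");
  std::memcpy(&Value, &Bits, sizeof(Value));
  return Value;
}

double allOnesDouble() {
  const uint64_t Bits = ~UINT64_C(0); // 0xFFFFFFFFFFFFFFFF
  double Value;
  static_assert(sizeof(Value) == sizeof(Bits), "expects a 64-bit double");
  std::memcpy(&Value, &Bits, sizeof(Value));
  return Value;
}

// Sign bit set, exponent all ones, non-zero mantissa: a negative NaN.
template <typename FP> bool isNegativeNaN(FP Value) {
  return std::isnan(Value) && std::signbit(Value);
}
```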
diff --git a/llvm/test/Assembler/bfloat.ll b/llvm/test/Assembler/bfloat.ll
new file mode 100644
index 000000000000..c9c7b6d26868
--- /dev/null
+++ b/llvm/test/Assembler/bfloat.ll
@@ -0,0 +1,38 @@
+; RUN: llvm-as < %s | llvm-dis | FileCheck %s --check-prefix=ASSEM-DISASS
+; RUN: opt < %s -O3 -S | FileCheck %s --check-prefix=OPT
+; RUN: verify-uselistorder %s
+; Basic smoke tests for bfloat type.
+
+define bfloat @check_bfloat(bfloat %A) {
+; ASSEM-DISASS: ret bfloat %A
+    ret bfloat %A
+}
+
+define bfloat @check_bfloat_literal() {
+; ASSEM-DISASS: ret bfloat 0xR3149
+    ret bfloat 0xR3149
+}
+
+define <4 x bfloat> @check_fixed_vector() {
+; ASSEM-DISASS: ret <4 x bfloat> %tmp
+  %tmp = fadd <4 x bfloat> undef, undef
+  ret <4 x bfloat> %tmp
+}
+
+define <vscale x 4 x bfloat> @check_vector() {
+; ASSEM-DISASS: ret <vscale x 4 x bfloat> %tmp
+  %tmp = fadd <vscale x 4 x bfloat> undef, undef
+  ret <vscale x 4 x bfloat> %tmp
+}
+
+define bfloat @check_bfloat_constprop() {
+  %tmp = fadd bfloat 0xR40C0, 0xR40C0
+; OPT: 0xR4140
+  ret bfloat %tmp
+}
+
+define float @check_bfloat_convert() {
+  %tmp = fpext bfloat 0xR4C8D to float
+; OPT: 0x4191A00000000000
+  ret float %tmp
+}

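The constants in bfloat.ll can be reproduced with the common bit-level shortcuts for bfloat: widening to float is a 16-bit left shift, and narrowing keeps the top 16 bits with round-to-nearest-even. This is a hypothetical sketch for checking the expected values, not how APFloat performs the conversion; doubleBits() is included because IR prints float constants in hexadecimal using the double encoding.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical helpers, not LLVM APIs. bfloat shares IEEE single's sign and
// exponent layout, so widening is exact: the 16 bits become the upper half.
float bfloatToFloat(uint16_t Bits) {
  const uint32_t Wide = static_cast<uint32_t>(Bits) << 16;
  float Value;
  std::memcpy(&Value, &Wide, sizeof(Value));
  return Value;
}

// Narrowing keeps the upper 16 bits, rounding to nearest with ties to even.
// NaNs are kept as quiet NaNs (with their sign) instead of being rounded.
uint16_t floatToBFloat(float Value) {
  uint32_t Bits;
  std::memcpy(&Bits, &Value, sizeof(Bits));
  if ((Bits & 0x7fffffffu) > 0x7f800000u)
    return static_cast<uint16_t>((Bits >> 16) | 0x0040u);
  const uint32_t Lsb = (Bits >> 16) & 1u; // lowest bit that will be kept
  Bits += 0x7fffu + Lsb;                  // carries into bit 16 when rounding up
  return static_cast<uint16_t>(Bits >> 16);
}

// IR prints float constants in hex using the double encoding, so expose it.
uint64_t doubleBits(double Value) {
  uint64_t Bits;
  std::memcpy(&Bits, &Value, sizeof(Bits));
  return Bits;
}
```

Here 0xR40C0 is 6.0, the folded sum 12.0 narrows to 0x4140, and 0xR4C8D widens to 73924608.0, whose double encoding is 0x4191A00000000000, as the OPT check lines expect.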
diff --git a/llvm/tools/llvm-c-test/echo.cpp b/llvm/tools/llvm-c-test/echo.cpp
index bf284da95935..49b9f74382b9 100644
--- a/llvm/tools/llvm-c-test/echo.cpp
+++ b/llvm/tools/llvm-c-test/echo.cpp
@@ -72,6 +72,8 @@ struct TypeCloner {
         return LLVMVoidTypeInContext(Ctx);
       case LLVMHalfTypeKind:
         return LLVMHalfTypeInContext(Ctx);
+      case LLVMBFloatTypeKind:
+        return LLVMBFloatTypeInContext(Ctx);
       case LLVMFloatTypeKind:
         return LLVMFloatTypeInContext(Ctx);
       case LLVMDoubleTypeKind:
