[clang] [llvm] [RFC][AMDGPU] Use `bf16` instead of `i16` for bfloat (PR #80908)

Shilei Tian via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 12 10:48:50 PST 2024


https://github.com/shiltian updated https://github.com/llvm/llvm-project/pull/80908

>From 4196e998349d663a9a9922937cc4bedbec95fe5f Mon Sep 17 00:00:00 2001
From: Shilei Tian <i at tianshilei.me>
Date: Mon, 12 Feb 2024 13:48:39 -0500
Subject: [PATCH] [RFC][WIP][AMDGPU] Use `bf16` instead of `i16` for bfloat

Currently, it looks like we generally use `i16` to represent `bf16` in the AMDGPU
TableGen files. I'm not sure of the reason behind it; my guess is that the `bf16`
type was not available when this support was first enabled. This patch uses `bf16`
directly in those TableGen files, with the aim of fixing #79369. A possible
workaround for #79369 would be to treat all `INT16` variants as `BFloat` in
`getOpFltSemantics`, but that doesn't look like a good solution IMHO.

Since I'm fairly new to the AMDGPU backend, I'd appreciate it if you could point
out anything I have misunderstood.
---
 clang/lib/CodeGen/CGBuiltin.cpp               |  4 --
 llvm/include/llvm/IR/IntrinsicsAMDGPU.td      |  8 +--
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp  |  5 +-
 .../AMDGPU/AsmParser/AMDGPUAsmParser.cpp      | 71 +++++++++++++++++++
 .../AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp | 59 +++++++++++++++
 .../AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h   |  2 +
 .../MCTargetDesc/AMDGPUMCCodeEmitter.cpp      |  7 ++
 llvm/lib/Target/AMDGPU/SIDefines.h            |  7 ++
 llvm/lib/Target/AMDGPU/SIInstrInfo.cpp        |  8 +++
 llvm/lib/Target/AMDGPU/SIInstrInfo.td         | 58 ++++++++-------
 llvm/lib/Target/AMDGPU/SIRegisterInfo.td      | 22 +++++-
 .../Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp    | 48 ++++++++++---
 llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h | 15 ++++
 llvm/lib/Target/AMDGPU/VOP3Instructions.td    |  2 +-
 .../AMDGPU/llvm.amdgcn.fdot2.bf16.bf16.ll     | 36 +++++-----
 llvm/test/MC/AMDGPU/bf16_imm.s                |  8 +++
 16 files changed, 295 insertions(+), 65 deletions(-)
 create mode 100644 llvm/test/MC/AMDGPU/bf16_imm.s

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index a7a410dab1a018..daf651917f2a96 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -5908,8 +5908,6 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
           }
         }
 
-        assert(ArgValue->getType()->canLosslesslyBitCastTo(PTy) &&
-               "Must be able to losslessly bit cast to param");
         // Cast vector type (e.g., v256i32) to x86_amx, this only happen
         // in amx intrinsics.
         if (PTy->isX86_AMXTy())
@@ -5939,8 +5937,6 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
         }
       }
 
-      assert(V->getType()->canLosslesslyBitCastTo(RetTy) &&
-             "Must be able to losslessly bit cast result type");
       // Cast x86_amx to vector type (e.g., v256i32), this only happen
       // in amx intrinsics.
       if (V->getType()->isX86_AMXTy())
diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index 202fa4e8f4ea81..6795fb7aa0edb8 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -2819,11 +2819,11 @@ def int_amdgcn_fdot2_f16_f16 :
 def int_amdgcn_fdot2_bf16_bf16 :
   ClangBuiltin<"__builtin_amdgcn_fdot2_bf16_bf16">,
   DefaultAttrsIntrinsic<
-    [llvm_i16_ty],   // %r
+    [llvm_bfloat_ty],   // %r
     [
-      llvm_v2i16_ty, // %a
-      llvm_v2i16_ty, // %b
-      llvm_i16_ty    // %c
+      llvm_v2bf16_ty, // %a
+      llvm_v2bf16_ty, // %b
+      llvm_bfloat_ty    // %c
     ],
     [IntrNoMem, IntrSpeculatable]
   >;
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index c1d8e890a66edb..828229f3e569e3 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -1562,8 +1562,9 @@ bool IRTranslator::translateBitCast(const User &U,
 
 bool IRTranslator::translateCast(unsigned Opcode, const User &U,
                                  MachineIRBuilder &MIRBuilder) {
-  if (U.getType()->getScalarType()->isBFloatTy() ||
-      U.getOperand(0)->getType()->getScalarType()->isBFloatTy())
+  if (Opcode != TargetOpcode::G_BITCAST &&
+      (U.getType()->getScalarType()->isBFloatTy() ||
+       U.getOperand(0)->getType()->getScalarType()->isBFloatTy()))
     return false;
   Register Op = getOrCreateVReg(*U.getOperand(0));
   Register Res = getOrCreateVReg(U);
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index a94da992b33859..65d6fb587c19ca 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -475,6 +475,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
 
   bool isSSrcF64() const { return isSCSrc_b64() || isLiteralImm(MVT::f64); }
 
+  bool isSSrc_bf16() const { return isSCSrcB16() || isLiteralImm(MVT::bf16); }
+
   bool isSSrc_f16() const { return isSCSrcB16() || isLiteralImm(MVT::f16); }
 
   bool isSSrcV2F16() const {
@@ -541,22 +543,40 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     return isRegOrInlineNoMods(AMDGPU::VS_64RegClassID, MVT::f64);
   }
 
+  bool isVCSrcTBF16() const {
+    return isRegOrInlineNoMods(AMDGPU::VS_16RegClassID, MVT::bf16);
+  }
+
   bool isVCSrcTF16() const {
     return isRegOrInlineNoMods(AMDGPU::VS_16RegClassID, MVT::f16);
   }
 
+  bool isVCSrcTBF16_Lo128() const {
+    return isRegOrInlineNoMods(AMDGPU::VS_16_Lo128RegClassID, MVT::bf16);
+  }
+
   bool isVCSrcTF16_Lo128() const {
     return isRegOrInlineNoMods(AMDGPU::VS_16_Lo128RegClassID, MVT::f16);
   }
 
+  bool isVCSrcFake16BF16_Lo128() const {
+    return isRegOrInlineNoMods(AMDGPU::VS_32_Lo128RegClassID, MVT::bf16);
+  }
+
   bool isVCSrcFake16F16_Lo128() const {
     return isRegOrInlineNoMods(AMDGPU::VS_32_Lo128RegClassID, MVT::f16);
   }
 
+  bool isVCSrc_bf16() const {
+    return isRegOrInlineNoMods(AMDGPU::VS_32RegClassID, MVT::bf16);
+  }
+
   bool isVCSrc_f16() const {
     return isRegOrInlineNoMods(AMDGPU::VS_32RegClassID, MVT::f16);
   }
 
+  bool isVCSrc_v2bf16() const { return isVCSrc_bf16(); }
+
   bool isVCSrc_v2f16() const { return isVCSrc_f16(); }
 
   bool isVSrc_b32() const {
@@ -597,18 +617,34 @@ class AMDGPUOperand : public MCParsedAsmOperand {
 
   bool isVSrc_f64() const { return isVCSrcF64() || isLiteralImm(MVT::f64); }
 
+  bool isVSrcT_bf16() const { return isVCSrcTBF16() || isLiteralImm(MVT::bf16); }
+
   bool isVSrcT_f16() const { return isVCSrcTF16() || isLiteralImm(MVT::f16); }
 
+  bool isVSrcT_bf16_Lo128() const {
+    return isVCSrcTBF16_Lo128() || isLiteralImm(MVT::bf16);
+  }
+
   bool isVSrcT_f16_Lo128() const {
     return isVCSrcTF16_Lo128() || isLiteralImm(MVT::f16);
   }
 
+  bool isVSrcFake16_bf16_Lo128() const {
+    return isVCSrcFake16BF16_Lo128() || isLiteralImm(MVT::bf16);
+  }
+
   bool isVSrcFake16_f16_Lo128() const {
     return isVCSrcFake16F16_Lo128() || isLiteralImm(MVT::f16);
   }
 
+  bool isVSrc_bf16() const { return isVCSrc_bf16() || isLiteralImm(MVT::bf16); }
+
   bool isVSrc_f16() const { return isVCSrc_f16() || isLiteralImm(MVT::f16); }
 
+  bool isVSrc_v2bf16() const {
+    return isVSrc_bf16() || isLiteralImm(MVT::v2bf16);
+  }
+
   bool isVSrc_v2f16() const { return isVSrc_f16() || isLiteralImm(MVT::v2f16); }
 
   bool isVISrcB32() const {
@@ -635,6 +671,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     return isVISrcF16() || isVISrcB32();
   }
 
+  bool isVISrc_64_bf16() const {
+    return isRegOrInlineNoMods(AMDGPU::VReg_64RegClassID, MVT::bf16);
+  }
+
   bool isVISrc_64_f16() const {
     return isRegOrInlineNoMods(AMDGPU::VReg_64RegClassID, MVT::f16);
   }
@@ -803,6 +843,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     return isAISrc_128F16() || isAISrc_128_b32();
   }
 
+  bool isVISrc_128_bf16() const {
+    return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::bf16);
+  }
+
   bool isVISrc_128_f16() const {
     return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::f16);
   }
@@ -1890,6 +1934,14 @@ static const fltSemantics *getOpFltSemantics(uint8_t OperandType) {
   case AMDGPU::OPERAND_REG_IMM_V2FP16:
   case AMDGPU::OPERAND_KIMM16:
     return &APFloat::IEEEhalf();
+  case AMDGPU::OPERAND_REG_IMM_BF16:
+  case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
+  case AMDGPU::OPERAND_REG_INLINE_C_BF16:
+  case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
+  case AMDGPU::OPERAND_REG_IMM_V2BF16:
+    return &APFloat::BFloat();
   default:
     llvm_unreachable("unsupported fp type");
   }
@@ -2186,17 +2238,24 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
     case AMDGPU::OPERAND_REG_INLINE_AC_INT32:
     case AMDGPU::OPERAND_REG_INLINE_AC_FP32:
     case AMDGPU::OPERAND_REG_IMM_INT16:
+    case AMDGPU::OPERAND_REG_IMM_BF16:
     case AMDGPU::OPERAND_REG_IMM_FP16:
+    case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
     case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
     case AMDGPU::OPERAND_REG_INLINE_C_INT16:
+    case AMDGPU::OPERAND_REG_INLINE_C_BF16:
     case AMDGPU::OPERAND_REG_INLINE_C_FP16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
+    case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
     case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
+    case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
     case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
     case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
+    case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
     case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
     case AMDGPU::OPERAND_REG_IMM_V2INT16:
+    case AMDGPU::OPERAND_REG_IMM_V2BF16:
     case AMDGPU::OPERAND_REG_IMM_V2FP16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2FP32:
     case AMDGPU::OPERAND_REG_IMM_V2FP32:
@@ -2240,6 +2299,7 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
   case AMDGPU::OPERAND_REG_INLINE_AC_INT32:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP32:
   case AMDGPU::OPERAND_REG_IMM_V2INT16:
+  case AMDGPU::OPERAND_REG_IMM_V2BF16:
   case AMDGPU::OPERAND_REG_IMM_V2FP16:
   case AMDGPU::OPERAND_REG_IMM_V2FP32:
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP32:
@@ -2277,11 +2337,15 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
     return;
 
   case AMDGPU::OPERAND_REG_IMM_INT16:
+  case AMDGPU::OPERAND_REG_IMM_BF16:
   case AMDGPU::OPERAND_REG_IMM_FP16:
+  case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
   case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
   case AMDGPU::OPERAND_REG_INLINE_C_INT16:
+  case AMDGPU::OPERAND_REG_INLINE_C_BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_FP16:
   case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
     if (isSafeTruncation(Val, 16) &&
         AMDGPU::isInlinableLiteral16(static_cast<int16_t>(Val),
@@ -2296,8 +2360,10 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
     return;
 
   case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
+  case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16: {
     assert(isSafeTruncation(Val, 16));
     assert(AMDGPU::isInlinableLiteral16(static_cast<int16_t>(Val),
@@ -3434,6 +3500,11 @@ bool AMDGPUAsmParser::isInlineConstant(const MCInst &Inst,
         OperandType == AMDGPU::OPERAND_REG_IMM_V2FP16)
       return AMDGPU::isInlinableLiteralV2F16(Val);
 
+    if (OperandType == AMDGPU::OPERAND_REG_INLINE_C_V2BF16 ||
+        OperandType == AMDGPU::OPERAND_REG_INLINE_AC_V2BF16 ||
+        OperandType == AMDGPU::OPERAND_REG_IMM_V2BF16)
+      return AMDGPU::isInlinableLiteralV2BF16(Val);
+
     return AMDGPU::isInlinableLiteral16(Val, hasInv2PiInlineImm());
   }
   default:
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
index 4ab3aa5a0240ad..e1a837eb230d1e 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
@@ -488,6 +488,49 @@ static bool printImmediateFloat16(uint32_t Imm, const MCSubtargetInfo &STI,
   return true;
 }
 
+static bool printImmediateBFloat16(uint32_t Imm, const MCSubtargetInfo &STI,
+                                   raw_ostream &O) {
+  if (Imm == 0x3F80)
+    O << "1.0";
+  else if (Imm == 0xBF80)
+    O << "-1.0";
+  else if (Imm == 0x3F00)
+    O << "0.5";
+  else if (Imm == 0xBF00)
+    O << "-0.5";
+  else if (Imm == 0x4000)
+    O << "2.0";
+  else if (Imm == 0xC000)
+    O << "-2.0";
+  else if (Imm == 0x4080)
+    O << "4.0";
+  else if (Imm == 0xC080)
+    O << "-4.0";
+  else if (Imm == 0x3E22 && STI.hasFeature(AMDGPU::FeatureInv2PiInlineImm))
+    O << "0.15915494";
+  else
+    return false;
+
+  return true;
+}
+
+void AMDGPUInstPrinter::printImmediateBF16(uint32_t Imm,
+                                           const MCSubtargetInfo &STI,
+                                           raw_ostream &O) {
+  int16_t SImm = static_cast<int16_t>(Imm);
+  if (isInlinableIntLiteral(SImm)) {
+    O << SImm;
+    return;
+  }
+
+  uint16_t HImm = static_cast<uint16_t>(Imm);
+  if (printImmediateBFloat16(HImm, STI, O))
+    return;
+
+  uint64_t Imm16 = static_cast<uint16_t>(Imm);
+  O << formatHex(Imm16);
+}
+
 void AMDGPUInstPrinter::printImmediate16(uint32_t Imm,
                                          const MCSubtargetInfo &STI,
                                          raw_ostream &O) {
@@ -528,6 +571,13 @@ void AMDGPUInstPrinter::printImmediateV216(uint32_t Imm, uint8_t OpType,
         printImmediateFloat16(static_cast<uint16_t>(Imm), STI, O))
       return;
     break;
+  case AMDGPU::OPERAND_REG_IMM_V2BF16:
+  case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
+    if (isUInt<16>(Imm) &&
+        printImmediateBFloat16(static_cast<uint16_t>(Imm), STI, O))
+      return;
+    break;
   default:
     llvm_unreachable("bad operand type");
   }
@@ -799,11 +849,20 @@ void AMDGPUInstPrinter::printRegularOperand(const MCInst *MI, unsigned OpNo,
     case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
       printImmediate16(Op.getImm(), STI, O);
       break;
+    case AMDGPU::OPERAND_REG_INLINE_C_BF16:
+    case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
+    case AMDGPU::OPERAND_REG_IMM_BF16:
+    case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
+      printImmediateBF16(Op.getImm(), STI, O);
+      break;
     case AMDGPU::OPERAND_REG_IMM_V2INT16:
+    case AMDGPU::OPERAND_REG_IMM_V2BF16:
     case AMDGPU::OPERAND_REG_IMM_V2FP16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
     case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
+    case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
+    case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
     case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
       printImmediateV216(Op.getImm(), OpTy, STI, O);
       break;
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
index e91ff86b219a0c..15ecbf2e5e5918 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
@@ -88,6 +88,8 @@ class AMDGPUInstPrinter : public MCInstPrinter {
                            raw_ostream &O);
   void printImmediate16(uint32_t Imm, const MCSubtargetInfo &STI,
                         raw_ostream &O);
+  void printImmediateBF16(uint32_t Imm, const MCSubtargetInfo &STI,
+                          raw_ostream &O);
   void printImmediateV216(uint32_t Imm, uint8_t OpType,
                           const MCSubtargetInfo &STI, raw_ostream &O);
   bool printImmediateFloat32(uint32_t Imm, const MCSubtargetInfo &STI,
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp
index 11f5e456e8d348..9ec174ba56c242 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp
@@ -276,9 +276,13 @@ AMDGPUMCCodeEmitter::getLitEncoding(const MCOperand &MO,
   case AMDGPU::OPERAND_REG_INLINE_C_INT16:
   case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
     return getLit16IntEncoding(static_cast<uint16_t>(Imm), STI);
+  case AMDGPU::OPERAND_REG_IMM_BF16:
   case AMDGPU::OPERAND_REG_IMM_FP16:
+  case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
   case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
+  case AMDGPU::OPERAND_REG_INLINE_C_BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_FP16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
     // FIXME Is this correct? What do inline immediates do on SI for f16 src
     // which does not have f16 support?
@@ -288,8 +292,11 @@ AMDGPUMCCodeEmitter::getLitEncoding(const MCOperand &MO,
   case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
     return AMDGPU::getInlineEncodingV2I16(static_cast<uint32_t>(Imm))
         .value_or(255);
+  case AMDGPU::OPERAND_REG_IMM_V2BF16:
   case AMDGPU::OPERAND_REG_IMM_V2FP16:
+  case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
     return AMDGPU::getInlineEncodingV2F16(static_cast<uint32_t>(Imm))
         .value_or(255);
diff --git a/llvm/lib/Target/AMDGPU/SIDefines.h b/llvm/lib/Target/AMDGPU/SIDefines.h
index 19596d53b45328..66b997eb180613 100644
--- a/llvm/lib/Target/AMDGPU/SIDefines.h
+++ b/llvm/lib/Target/AMDGPU/SIDefines.h
@@ -196,9 +196,12 @@ enum OperandType : unsigned {
   OPERAND_REG_IMM_INT16,
   OPERAND_REG_IMM_FP32,
   OPERAND_REG_IMM_FP64,
+  OPERAND_REG_IMM_BF16,
   OPERAND_REG_IMM_FP16,
+  OPERAND_REG_IMM_BF16_DEFERRED,
   OPERAND_REG_IMM_FP16_DEFERRED,
   OPERAND_REG_IMM_FP32_DEFERRED,
+  OPERAND_REG_IMM_V2BF16,
   OPERAND_REG_IMM_V2FP16,
   OPERAND_REG_IMM_V2INT16,
   OPERAND_REG_IMM_V2INT32,
@@ -208,10 +211,12 @@ enum OperandType : unsigned {
   OPERAND_REG_INLINE_C_INT16,
   OPERAND_REG_INLINE_C_INT32,
   OPERAND_REG_INLINE_C_INT64,
+  OPERAND_REG_INLINE_C_BF16,
   OPERAND_REG_INLINE_C_FP16,
   OPERAND_REG_INLINE_C_FP32,
   OPERAND_REG_INLINE_C_FP64,
   OPERAND_REG_INLINE_C_V2INT16,
+  OPERAND_REG_INLINE_C_V2BF16,
   OPERAND_REG_INLINE_C_V2FP16,
   OPERAND_REG_INLINE_C_V2INT32,
   OPERAND_REG_INLINE_C_V2FP32,
@@ -226,10 +231,12 @@ enum OperandType : unsigned {
   /// Operands with an AccVGPR register or inline constant
   OPERAND_REG_INLINE_AC_INT16,
   OPERAND_REG_INLINE_AC_INT32,
+  OPERAND_REG_INLINE_AC_BF16,
   OPERAND_REG_INLINE_AC_FP16,
   OPERAND_REG_INLINE_AC_FP32,
   OPERAND_REG_INLINE_AC_FP64,
   OPERAND_REG_INLINE_AC_V2INT16,
+  OPERAND_REG_INLINE_AC_V2BF16,
   OPERAND_REG_INLINE_AC_V2FP16,
   OPERAND_REG_INLINE_AC_V2INT32,
   OPERAND_REG_INLINE_AC_V2FP32,
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp
index c7628bd354309c..9bb9809f83e34c 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp
@@ -4185,9 +4185,17 @@ bool SIInstrInfo::isInlineConstant(const MachineOperand &MO,
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
     return AMDGPU::isInlinableLiteralV2F16(Imm);
+  case AMDGPU::OPERAND_REG_IMM_V2BF16:
+  case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
+    return AMDGPU::isInlinableLiteralV2BF16(Imm);
+  case AMDGPU::OPERAND_REG_IMM_BF16:
   case AMDGPU::OPERAND_REG_IMM_FP16:
+  case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
   case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
+  case AMDGPU::OPERAND_REG_INLINE_C_BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_FP16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP16: {
     if (isInt<16>(Imm) || isUInt<16>(Imm)) {
       // A few special case instructions have 16-bit operands on subtargets
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
index 22599773d562cb..b0daec4a350eb3 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
@@ -1497,20 +1497,17 @@ class getVOPSrc0ForVT<ValueType VT, bit IsTrue16, bit IsFake16 = 1> {
   RegisterOperand ret =
     !if(VT.isFP,
       !if(!eq(VT.Size, 64),
-         VSrc_f64,
-         !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
-            !if(IsTrue16,
-              !if(IsFake16, VSrcFake16_f16_Lo128, VSrcT_f16_Lo128),
-              VSrc_f16
-            ),
-            !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
-               VSrc_v2f16,
-               !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
-                 AVSrc_64,
-                 VSrc_f32
+          VSrc_f64,
+          !if(!eq(VT.Value, f16.Value),
+              !if(IsTrue16, !if(IsFake16, VSrcFake16_f16_Lo128, VSrcT_f16_Lo128), VSrc_f16),
+              !if(!eq(VT.Value, bf16.Value),
+                 !if(IsTrue16, !if(IsFake16, VSrcFake16_bf16_Lo128, VSrcT_bf16_Lo128), VSrc_bf16),
+                 !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
+                     !if(!eq(VT.Value, v2f16.Value), VSrc_v2f16, VSrc_v2bf16),
+                     !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)), AVSrc_64, VSrc_f32)
+                  )
                )
-            )
-         )
+           )
        ),
        !if(!eq(VT.Size, 64),
           VSrc_b64,
@@ -1569,16 +1566,20 @@ class getVOP3SrcForVT<ValueType VT, bit IsTrue16 = 0> {
         !if(!eq(VT.Value, i1.Value),
            SSrc_i1,
            !if(VT.isFP,
-              !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
-                 !if(IsTrue16, VSrcT_f16, VSrc_f16),
-                 !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
-                    VSrc_v2f16,
-                    !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
-                      AVSrc_64,
-                      VSrc_f32
-                    )
-                 )
-              ),
+               !if(!eq(VT.Value, f16.Value),
+                   !if(IsTrue16, VSrcT_f16, VSrc_f16),
+                   !if(!eq(VT.Value, bf16.Value),
+                       !if(IsTrue16, VSrcT_bf16, VSrc_bf16),
+                       !if(!eq(VT.Value, v2f16.Value),
+                           VSrc_v2f16,
+                           !if(!eq(VT.Value, v2bf16.Value),
+                               VSrc_v2bf16,
+                               !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
+                                   AVSrc_64, VSrc_f32)
+                           )
+                       )
+                   )
+               ),
               !if(!eq(VT.Value, i16.Value),
                  !if(IsTrue16, VSrcT_b16, VSrc_b16),
                  !if(!eq(VT.Value, v2i16.Value),
@@ -1597,8 +1598,13 @@ class getVOP3DPPSrcForVT<ValueType VT> {
   RegisterOperand ret =
       !if (!eq(VT.Value, i1.Value), SSrc_i1,
            !if (VT.isFP,
-                !if (!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)), VCSrc_f16,
-                     !if (!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)), VCSrc_v2f16, VCSrc_f32)),
+                !if(!eq(VT.Value, f16.Value), VCSrc_f16,
+                    !if(!eq(VT.Value, bf16.Value), VCSrc_bf16,
+                        !if(!eq(VT.Value, v2f16.Value), VCSrc_v2f16,
+                            !if(!eq(VT.Value, v2bf16.Value), VCSrc_v2bf16, VCSrc_f32)
+                        )
+                    )
+                ),
                 !if (!eq(VT.Value, i16.Value), VCSrc_b16,
                      !if (!eq(VT.Value, v2i16.Value), VCSrc_v2b16,
                           VCSrc_b32))));
@@ -2528,7 +2534,7 @@ def VOP_V2I16_F32_F32 : VOPProfile <[v2i16, f32, f32, untyped]>;
 def VOP_V2I16_I32_I32 : VOPProfile <[v2i16, i32, i32, untyped]>;
 
 def VOP_F16_V2F16_V2F16_F16 : VOPProfile <[f16, v2f16, v2f16, f16]>;
-def VOP_I16_V2I16_V2I16_I16 : VOPProfile <[i16, v2i16, v2i16, i16]>;
+def VOP_BF16_V2BF16_V2BF16_BF16: VOPProfile <[bf16, v2bf16, v2bf16, bf16]>;
 def VOP_F32_V2I16_V2I16_F32 : VOPProfile <[f32, v2i16, v2i16, f32]>;
 
 def VOP_F32_V2F16_V2F16_V2F16 : VOPProfile <[f32, v2f16, v2f16, v2f16]>;
diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
index aabb6c29062114..f24e65304d2052 100644
--- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
@@ -1066,7 +1066,7 @@ multiclass AVRegClass<int numRegs, list<ValueType> regTypes,
     // Define the regular class.
     def "" : VRegClassBase<numRegs, regTypes, (add vregList, aregList)>;
 
-    // Define 2-aligned variant 
+    // Define 2-aligned variant
     def _Align2 : VRegClassBase<numRegs, regTypes,
                                 (add (decimate vregList, 2),
                                      (decimate aregList, 2))> {
@@ -1115,6 +1115,7 @@ class RegOrImmOperand <string RegisterClassName, string OperandTypeName,
 //===----------------------------------------------------------------------===//
 
 def SSrc_b16 : RegOrImmOperand <"SReg_32", "OPERAND_REG_IMM_INT16", "_Imm16">;
+def SSrc_bf16: RegOrImmOperand <"SReg_32", "OPERAND_REG_IMM_BF16", "_Imm16">;
 def SSrc_f16 : RegOrImmOperand <"SReg_32", "OPERAND_REG_IMM_FP16", "_Imm16">;
 def SSrc_b32 : RegOrImmOperand <"SReg_32", "OPERAND_REG_IMM_INT32", "_Imm32">;
 def SSrc_f32 : RegOrImmOperand <"SReg_32", "OPERAND_REG_IMM_FP32", "_Imm32">;
@@ -1142,6 +1143,7 @@ def SCSrc_b64 : RegOrImmOperand <"SReg_64", "OPERAND_REG_INLINE_C_INT64", "_Imm6
 
 // The current and temporary future default used case for VOP3.
 def VSrc_b16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_INT16", "_Imm16">;
+def VSrc_bf16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_BF16", "_Imm16">;
 def VSrc_f16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_FP16", "_Imm16">;
 
 // True16 VOP3 operands.
@@ -1149,6 +1151,10 @@ def VSrcT_b16 : RegOrImmOperand <"VS_16", "OPERAND_REG_IMM_INT16", "_Imm16"> {
   let EncoderMethod = "getMachineOpValueT16";
   let DecoderMethod = "decodeOperand_VSrcT16";
 }
+def VSrcT_bf16 : RegOrImmOperand <"VS_16", "OPERAND_REG_IMM_BF16", "_Imm16"> {
+  let EncoderMethod = "getMachineOpValueT16";
+  let DecoderMethod = "decodeOperand_VSrcT16";
+}
 def VSrcT_f16 : RegOrImmOperand <"VS_16", "OPERAND_REG_IMM_FP16", "_Imm16"> {
   let EncoderMethod = "getMachineOpValueT16";
   let DecoderMethod = "decodeOperand_VSrcT16";
@@ -1159,6 +1165,10 @@ def VSrcT_b16_Lo128 : RegOrImmOperand <"VS_16_Lo128", "OPERAND_REG_IMM_INT16", "
   let EncoderMethod = "getMachineOpValueT16Lo128";
   let DecoderMethod = "decodeOperand_VSrcT16_Lo128";
 }
+def VSrcT_bf16_Lo128 : RegOrImmOperand <"VS_16_Lo128", "OPERAND_REG_IMM_BF16", "_Imm16"> {
+  let EncoderMethod = "getMachineOpValueT16Lo128";
+  let DecoderMethod = "decodeOperand_VSrcT16_Lo128";
+}
 def VSrcT_f16_Lo128 : RegOrImmOperand <"VS_16_Lo128", "OPERAND_REG_IMM_FP16", "_Imm16"> {
   let EncoderMethod = "getMachineOpValueT16Lo128";
   let DecoderMethod = "decodeOperand_VSrcT16_Lo128";
@@ -1167,11 +1177,13 @@ def VSrcT_f16_Lo128 : RegOrImmOperand <"VS_16_Lo128", "OPERAND_REG_IMM_FP16", "_
 // The current and temporary future default used case for fake VOP1/2/C.
 // For VOP1,2,C True16 instructions. _Lo128 use first 128 32-bit VGPRs only.
 def VSrcFake16_b16_Lo128 : RegOrImmOperand <"VS_32_Lo128", "OPERAND_REG_IMM_INT16", "_Imm16">;
+def VSrcFake16_bf16_Lo128 : RegOrImmOperand <"VS_32_Lo128", "OPERAND_REG_IMM_BF16", "_Imm16">;
 def VSrcFake16_f16_Lo128 : RegOrImmOperand <"VS_32_Lo128", "OPERAND_REG_IMM_FP16", "_Imm16">;
 
 def VSrc_b32 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_INT32", "_Imm32">;
 def VSrc_f32 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_FP32", "_Imm32">;
 def VSrc_v2b16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_V2INT16", "_ImmV2I16">;
+def VSrc_v2bf16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_V2BF16", "_ImmV2F16">;
 def VSrc_v2f16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_V2FP16", "_ImmV2F16">;
 def VSrc_b64 : RegOrImmOperand <"VS_64", "OPERAND_REG_IMM_INT64", "_Imm64">;
 def VSrc_f64 : RegOrImmOperand <"VS_64", "OPERAND_REG_IMM_FP64", "_Imm64"> {
@@ -1185,9 +1197,13 @@ def VSrc_v2f32 : RegOrImmOperand <"VS_64", "OPERAND_REG_IMM_V2FP32", "_Imm32">;
 //  with FMAMK/FMAAK
 //===----------------------------------------------------------------------===//
 
+def VSrc_bf16_Deferred : RegOrImmOperand<"VS_32", "OPERAND_REG_IMM_BF16_DEFERRED", "_Deferred_Imm16">;
 def VSrc_f16_Deferred : RegOrImmOperand<"VS_32", "OPERAND_REG_IMM_FP16_DEFERRED", "_Deferred_Imm16">;
 def VSrc_f32_Deferred : RegOrImmOperand<"VS_32", "OPERAND_REG_IMM_FP32_DEFERRED", "_Deferred_Imm32">;
 
+def VSrcFake16_bf16_Lo128_Deferred : RegOrImmOperand<"VS_32_Lo128",
+                                                    "OPERAND_REG_IMM_BF16_DEFERRED",
+                                                    "_Deferred_Imm16">;
 def VSrcFake16_f16_Lo128_Deferred : RegOrImmOperand<"VS_32_Lo128",
                                                     "OPERAND_REG_IMM_FP16_DEFERRED",
                                                     "_Deferred_Imm16">;
@@ -1258,19 +1274,23 @@ def ARegSrc_32 : AVOperand<AGPR_32, "decodeSrcA9", "OPW32">;
 //===----------------------------------------------------------------------===//
 
 def VCSrc_b16 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_INT16", "_Imm16">;
+def VCSrc_bf16: RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_BF16", "_Imm16">;
 def VCSrc_f16 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_FP16", "_Imm16">;
 def VCSrc_b32 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_INT32", "_Imm32">;
 def VCSrc_f32 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_FP32", "_Imm32">;
 def VCSrc_v2b16 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_V2INT16", "_ImmV2I16">;
+def VCSrc_v2bf16 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_V2BF16", "_ImmV2F16">;
 def VCSrc_v2f16 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_V2FP16", "_ImmV2F16">;
 
 //===----------------------------------------------------------------------===//
 //  VISrc_* Operands with a VGPR or an inline constant
 //===----------------------------------------------------------------------===//
 
+def VISrc_64_bf16 : RegOrImmOperand <"VReg_64", "OPERAND_REG_INLINE_C_BF16", "_Imm16">;
 def VISrc_64_f16 : RegOrImmOperand <"VReg_64", "OPERAND_REG_INLINE_C_FP16", "_Imm16">;
 def VISrc_64_b32 : RegOrImmOperand <"VReg_64", "OPERAND_REG_INLINE_C_INT32", "_Imm32">;
 def VISrc_64_f64 : RegOrImmOperand <"VReg_64", "OPERAND_REG_INLINE_C_FP64", "_Imm64">;
+def VISrc_128_bf16 : RegOrImmOperand <"VReg_128", "OPERAND_REG_INLINE_C_BF16", "_Imm16">;
 def VISrc_128_f16 : RegOrImmOperand <"VReg_128", "OPERAND_REG_INLINE_C_FP16", "_Imm16">;
 def VISrc_128_b32 : RegOrImmOperand <"VReg_128", "OPERAND_REG_INLINE_C_INT32", "_Imm32">;
 def VISrc_128_f32 : RegOrImmOperand <"VReg_128", "OPERAND_REG_INLINE_C_FP32", "_Imm32">;
diff --git a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp
index 800dfcf3076dd3..3a459d2192d382 100644
--- a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp
@@ -2660,15 +2660,34 @@ bool isInlinableLiteral16(int16_t Literal, bool HasInv2Pi) {
     return true;
 
   uint16_t Val = static_cast<uint16_t>(Literal);
-  return Val == 0x3C00 || // 1.0
-         Val == 0xBC00 || // -1.0
-         Val == 0x3800 || // 0.5
-         Val == 0xB800 || // -0.5
-         Val == 0x4000 || // 2.0
-         Val == 0xC000 || // -2.0
-         Val == 0x4400 || // 4.0
-         Val == 0xC400 || // -4.0
-         Val == 0x3118;   // 1/2pi
+
+  // FP16
+  if (Val == 0x3C00 || // 1.0
+      Val == 0xBC00 || // -1.0
+      Val == 0x3800 || // 0.5
+      Val == 0xB800 || // -0.5
+      Val == 0x4000 || // 2.0
+      Val == 0xC000 || // -2.0
+      Val == 0x4400 || // 4.0
+      Val == 0xC400 || // -4.0
+      Val == 0x3118    // 1/2pi
+  )
+    return true;
+
+  // BF16
+  if (Val == 0x3F80 || // 1.0
+      Val == 0xBF80 || // -1.0
+      Val == 0x3F00 || // 0.5
+      Val == 0xBF00 || // -0.5
+      Val == 0x4000 || // 2.0
+      Val == 0xC000 || // -2.0
+      Val == 0x4080 || // 4.0
+      Val == 0xC080 || // -4.0
+      Val == 0x3E22    // 1/2pi
+  )
+    return true;
+
+  return false;
 }
 
 std::optional<unsigned> getInlineEncodingV216(bool IsFloat, uint32_t Literal) {
@@ -2730,6 +2749,12 @@ std::optional<unsigned> getInlineEncodingV2I16(uint32_t Literal) {
   return getInlineEncodingV216(false, Literal);
 }
 
+// Encoding of the literal as an inline constant for a V_PK_*_BF16 instruction
+// or nullopt. NOTE(review): IsFloat=true may select FP16 patterns; confirm.
+std::optional<unsigned> getInlineEncodingV2BF16(uint32_t Literal) {
+  return getInlineEncodingV216(true, Literal);
+}
+
 // Encoding of the literal as an inline constant for a V_PK_*_F16 instruction
 // or nullopt.
 std::optional<unsigned> getInlineEncodingV2F16(uint32_t Literal) {
@@ -2757,6 +2782,11 @@ bool isInlinableLiteralV2I16(uint32_t Literal) {
   return getInlineEncodingV2I16(Literal).has_value();
 }
 
+// Whether the given literal can be inlined for a V_PK_*_BF16 instruction.
+bool isInlinableLiteralV2BF16(uint32_t Literal) {
+  return getInlineEncodingV2BF16(Literal).has_value();
+}
+
 // Whether the given literal can be inlined for a V_PK_*_F16 instruction.
 bool isInlinableLiteralV2F16(uint32_t Literal) {
   return getInlineEncodingV2F16(Literal).has_value();
diff --git a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h
index b56025f55519a5..92e8cbf35fc3e1 100644
--- a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h
+++ b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h
@@ -1335,17 +1335,24 @@ inline unsigned getOperandSize(const MCOperandInfo &OpInfo) {
     return 8;
 
   case AMDGPU::OPERAND_REG_IMM_INT16:
+  case AMDGPU::OPERAND_REG_IMM_BF16:
   case AMDGPU::OPERAND_REG_IMM_FP16:
+  case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
   case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
   case AMDGPU::OPERAND_REG_INLINE_C_INT16:
+  case AMDGPU::OPERAND_REG_INLINE_C_BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_FP16:
   case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
+  case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
   case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
   case AMDGPU::OPERAND_REG_IMM_V2INT16:
+  case AMDGPU::OPERAND_REG_IMM_V2BF16:
   case AMDGPU::OPERAND_REG_IMM_V2FP16:
     return 2;
 
@@ -1373,12 +1380,17 @@ bool isInlinableLiteral64(int64_t Literal, bool HasInv2Pi);
 LLVM_READNONE
 bool isInlinableLiteral32(int32_t Literal, bool HasInv2Pi);
 
+enum Literal16Kind { I16, FP16, BF16 };
+
 LLVM_READNONE
 bool isInlinableLiteral16(int16_t Literal, bool HasInv2Pi);
 
 LLVM_READNONE
 std::optional<unsigned> getInlineEncodingV2I16(uint32_t Literal);
 
+LLVM_READNONE
+std::optional<unsigned> getInlineEncodingV2BF16(uint32_t Literal);
+
 LLVM_READNONE
 std::optional<unsigned> getInlineEncodingV2F16(uint32_t Literal);
 
@@ -1388,6 +1400,9 @@ bool isInlinableLiteralV216(uint32_t Literal, uint8_t OpType);
 LLVM_READNONE
 bool isInlinableLiteralV2I16(uint32_t Literal);
 
+LLVM_READNONE
+bool isInlinableLiteralV2BF16(uint32_t Literal);
+
 LLVM_READNONE
 bool isInlinableLiteralV2F16(uint32_t Literal);
 
diff --git a/llvm/lib/Target/AMDGPU/VOP3Instructions.td b/llvm/lib/Target/AMDGPU/VOP3Instructions.td
index 8d965d3b9041d5..35cffa22f45929 100644
--- a/llvm/lib/Target/AMDGPU/VOP3Instructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3Instructions.td
@@ -904,7 +904,7 @@ let SubtargetPredicate = isGFX12Plus, ReadsModeReg = 0 in {
 
 let SubtargetPredicate = HasDot9Insts, IsDOT=1 in {
   defm V_DOT2_F16_F16 :   VOP3Inst<"v_dot2_f16_f16",   VOP3_DOT_Profile<VOP_F16_V2F16_V2F16_F16>, int_amdgcn_fdot2_f16_f16>;
-  defm V_DOT2_BF16_BF16 : VOP3Inst<"v_dot2_bf16_bf16", VOP3_DOT_Profile<VOP_I16_V2I16_V2I16_I16>, int_amdgcn_fdot2_bf16_bf16>;
+  defm V_DOT2_BF16_BF16 : VOP3Inst<"v_dot2_bf16_bf16", VOP3_DOT_Profile<VOP_BF16_V2BF16_V2BF16_BF16>, int_amdgcn_fdot2_bf16_bf16>;
 }
 
 class VOP_Pseudo_Scalar<RegisterClass Dst, RegisterOperand SrcOp,
diff --git a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.fdot2.bf16.bf16.ll b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.fdot2.bf16.bf16.ll
index 645e00b6a4a819..b1b44983e9daf8 100644
--- a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.fdot2.bf16.bf16.ll
+++ b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.fdot2.bf16.bf16.ll
@@ -2,7 +2,7 @@
 ; RUN: llc -mtriple=amdgcn -mcpu=gfx1100 -verify-machineinstrs < %s | FileCheck %s --check-prefixes=GFX11,SDAG-GFX11
 ; RUN: llc -global-isel -mtriple=amdgcn -mcpu=gfx1100 -verify-machineinstrs < %s | FileCheck %s --check-prefixes=GFX11,GISEL-GFX11
 
-declare i16 @llvm.amdgcn.fdot2.bf16.bf16(<2 x i16> %a, <2 x i16> %b, i16 %c)
+declare bfloat @llvm.amdgcn.fdot2.bf16.bf16(<2 x bfloat> %a, <2 x bfloat> %b, bfloat %c)
 
 define amdgpu_kernel void @test_llvm_amdgcn_fdot2_bf16_bf16(
 ; GFX11-LABEL: test_llvm_amdgcn_fdot2_bf16_bf16:
@@ -24,11 +24,11 @@ define amdgpu_kernel void @test_llvm_amdgcn_fdot2_bf16_bf16(
     ptr addrspace(1) %b,
     ptr addrspace(1) %c) {
 entry:
-  %a.val = load <2 x i16>, ptr addrspace(1) %a
-  %b.val = load <2 x i16>, ptr addrspace(1) %b
-  %c.val = load i16, ptr addrspace(1) %c
-  %r.val = call i16 @llvm.amdgcn.fdot2.bf16.bf16(<2 x i16> %a.val, <2 x i16> %b.val, i16 %c.val)
-  store i16 %r.val, ptr addrspace(1) %r
+  %a.val = load <2 x bfloat>, ptr addrspace(1) %a
+  %b.val = load <2 x bfloat>, ptr addrspace(1) %b
+  %c.val = load bfloat, ptr addrspace(1) %c
+  %r.val = call bfloat @llvm.amdgcn.fdot2.bf16.bf16(<2 x bfloat> %a.val, <2 x bfloat> %b.val, bfloat %c.val)
+  store bfloat %r.val, ptr addrspace(1) %r
   ret void
 }
 
@@ -61,14 +61,14 @@ define amdgpu_kernel void @test_llvm_amdgcn_fdot2_bf16_bf16_dpp(
     ptr addrspace(5) %b,
     ptr addrspace(5) %c) {
 entry:
-  %a.val = load <2 x i16>, ptr addrspace(5) %a
-  %b.val = load <2 x i16>, ptr addrspace(5) %b
-  %c.val = load i16, ptr addrspace(5) %c
-  %a.val.i32 = bitcast <2 x i16> %a.val to i32
+  %a.val = load <2 x bfloat>, ptr addrspace(5) %a
+  %b.val = load <2 x bfloat>, ptr addrspace(5) %b
+  %c.val = load bfloat, ptr addrspace(5) %c
+  %a.val.i32 = bitcast <2 x bfloat> %a.val to i32
   %dpp = call i32 @llvm.amdgcn.update.dpp.i32(i32 %a.val.i32, i32 %a.val.i32, i32 1, i32 15, i32 15, i1 1)
-  %a.val.dpp.v2i16 = bitcast i32 %dpp to <2 x i16>
-  %r.val = call i16 @llvm.amdgcn.fdot2.bf16.bf16(<2 x i16> %a.val.dpp.v2i16, <2 x i16> %b.val, i16 %c.val)
-  store i16 %r.val, ptr addrspace(5) %r
+  %a.val.dpp.v2bfloat = bitcast i32 %dpp to <2 x bfloat>
+  %r.val = call bfloat @llvm.amdgcn.fdot2.bf16.bf16(<2 x bfloat> %a.val.dpp.v2bfloat, <2 x bfloat> %b.val, bfloat %c.val)
+  store bfloat %r.val, ptr addrspace(5) %r
   ret void
 }
 
@@ -79,17 +79,17 @@ define amdgpu_ps void @test_llvm_amdgcn_fdot2_bf16_bf16_sis(
 ; GFX11:       ; %bb.0: ; %entry
 ; GFX11-NEXT:    v_mov_b32_e32 v2, s1
 ; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_1)
-; GFX11-NEXT:    v_dot2_bf16_bf16 v2, s0, 0x10001, v2
+; GFX11-NEXT:    v_dot2_bf16_bf16 v2, s0, 0x3f803f80, v2
 ; GFX11-NEXT:    global_store_b16 v[0:1], v2, off
 ; GFX11-NEXT:    s_nop 0
 ; GFX11-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
 ; GFX11-NEXT:    s_endpgm
     ptr addrspace(1) %r,
-    <2 x i16> inreg %a,
-    i16 inreg %c) {
+    <2 x bfloat> inreg %a,
+    bfloat inreg %c) {
 entry:
-  %r.val = call i16 @llvm.amdgcn.fdot2.bf16.bf16(<2 x i16> %a, <2 x i16> <i16 1, i16 1>, i16 %c)
-  store i16 %r.val, ptr addrspace(1) %r
+  %r.val = call bfloat @llvm.amdgcn.fdot2.bf16.bf16(<2 x bfloat> %a, <2 x bfloat> <bfloat 1.0, bfloat 1.0>, bfloat %c)
+  store bfloat %r.val, ptr addrspace(1) %r
   ret void
 }
 
diff --git a/llvm/test/MC/AMDGPU/bf16_imm.s b/llvm/test/MC/AMDGPU/bf16_imm.s
new file mode 100644
index 00000000000000..f92a6b9a649e42
--- /dev/null
+++ b/llvm/test/MC/AMDGPU/bf16_imm.s
@@ -0,0 +1,8 @@
+// RUN: llvm-mc -arch=amdgcn -mcpu=gfx1100 -show-encoding %s | FileCheck %s
+// RUN: llvm-mc -arch=amdgcn -mcpu=gfx1200 -show-encoding %s | FileCheck %s
+
+v_dot2_bf16_bf16 v5, v1, v2, 100.0
+// CHECK: v_dot2_bf16_bf16 v5, v1, v2, 0x42c8 ; encoding: [0x05,0x00,0x67,0xd6,0x01,0x05,0xfe,0x03,0xc8,0x42,0x00,0x00]
+
+v_dot2_bf16_bf16 v5, v1, v2, 1.0
+// CHECK: v_dot2_bf16_bf16 v5, v1, v2, 1.0 ; encoding: [0x05,0x00,0x67,0xd6,0x01,0x05,0xfe,0x03,0x80,0x3f,0x00,0x00]



More information about the llvm-commits mailing list