[clang] [llvm] [RFC][AMDGPU] Use `bf16` instead of `i16` for bfloat (PR #80908)

via cfe-commits cfe-commits at lists.llvm.org
Thu Feb 8 09:13:50 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-ir

Author: Shilei Tian (shiltian)

<details>
<summary>Changes</summary>

Currently we generally use `i16` to represent `bf16` in these TableGen
files. I'm not sure why; my guess is that the `bf16` type was not yet
available when this support was first added. This patch switches those
TableGen files to use `bf16` directly, with the aim of fixing #79369.

---

Patch is 32.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/80908.diff


14 Files Affected:

- (modified) clang/lib/CodeGen/CGBuiltin.cpp (-4) 
- (modified) llvm/include/llvm/IR/IntrinsicsAMDGPU.td (+4-4) 
- (modified) llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp (+3-2) 
- (modified) llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp (+66) 
- (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp (+10) 
- (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp (+7) 
- (modified) llvm/lib/Target/AMDGPU/SIDefines.h (+7) 
- (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.cpp (+7) 
- (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.td (+32-26) 
- (modified) llvm/lib/Target/AMDGPU/SIRegisterInfo.td (+21-1) 
- (modified) llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h (+7) 
- (modified) llvm/lib/Target/AMDGPU/VOP3Instructions.td (+1-1) 
- (modified) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.fdot2.bf16.bf16.ll (+18-18) 
- (added) llvm/test/MC/AMDGPU/bf16_imm.s (+8) 


``````````diff
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index a7a410dab1a018..daf651917f2a96 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -5908,8 +5908,6 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
           }
         }
 
-        assert(ArgValue->getType()->canLosslesslyBitCastTo(PTy) &&
-               "Must be able to losslessly bit cast to param");
         // Cast vector type (e.g., v256i32) to x86_amx, this only happen
         // in amx intrinsics.
         if (PTy->isX86_AMXTy())
@@ -5939,8 +5937,6 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
         }
       }
 
-      assert(V->getType()->canLosslesslyBitCastTo(RetTy) &&
-             "Must be able to losslessly bit cast result type");
       // Cast x86_amx to vector type (e.g., v256i32), this only happen
       // in amx intrinsics.
       if (V->getType()->isX86_AMXTy())
diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index 202fa4e8f4ea81..6795fb7aa0edb8 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -2819,11 +2819,11 @@ def int_amdgcn_fdot2_f16_f16 :
 def int_amdgcn_fdot2_bf16_bf16 :
   ClangBuiltin<"__builtin_amdgcn_fdot2_bf16_bf16">,
   DefaultAttrsIntrinsic<
-    [llvm_i16_ty],   // %r
+    [llvm_bfloat_ty],   // %r
     [
-      llvm_v2i16_ty, // %a
-      llvm_v2i16_ty, // %b
-      llvm_i16_ty    // %c
+      llvm_v2bf16_ty, // %a
+      llvm_v2bf16_ty, // %b
+      llvm_bfloat_ty    // %c
     ],
     [IntrNoMem, IntrSpeculatable]
   >;
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index c1d8e890a66edb..828229f3e569e3 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -1562,8 +1562,9 @@ bool IRTranslator::translateBitCast(const User &U,
 
 bool IRTranslator::translateCast(unsigned Opcode, const User &U,
                                  MachineIRBuilder &MIRBuilder) {
-  if (U.getType()->getScalarType()->isBFloatTy() ||
-      U.getOperand(0)->getType()->getScalarType()->isBFloatTy())
+  if (Opcode != TargetOpcode::G_BITCAST &&
+      (U.getType()->getScalarType()->isBFloatTy() ||
+       U.getOperand(0)->getType()->getScalarType()->isBFloatTy()))
     return false;
   Register Op = getOrCreateVReg(*U.getOperand(0));
   Register Res = getOrCreateVReg(U);
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index a94da992b33859..d6d96c251f7e30 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -475,6 +475,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
 
   bool isSSrcF64() const { return isSCSrc_b64() || isLiteralImm(MVT::f64); }
 
+  bool isSSrc_bf16() const { return isSCSrcB16() || isLiteralImm(MVT::bf16); }
+
   bool isSSrc_f16() const { return isSCSrcB16() || isLiteralImm(MVT::f16); }
 
   bool isSSrcV2F16() const {
@@ -541,22 +543,40 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     return isRegOrInlineNoMods(AMDGPU::VS_64RegClassID, MVT::f64);
   }
 
+  bool isVCSrcTBF16() const {
+    return isRegOrInlineNoMods(AMDGPU::VS_16RegClassID, MVT::bf16);
+  }
+
   bool isVCSrcTF16() const {
     return isRegOrInlineNoMods(AMDGPU::VS_16RegClassID, MVT::f16);
   }
 
+  bool isVCSrcTBF16_Lo128() const {
+    return isRegOrInlineNoMods(AMDGPU::VS_16_Lo128RegClassID, MVT::bf16);
+  }
+
   bool isVCSrcTF16_Lo128() const {
     return isRegOrInlineNoMods(AMDGPU::VS_16_Lo128RegClassID, MVT::f16);
   }
 
+  bool isVCSrcFake16BF16_Lo128() const {
+    return isRegOrInlineNoMods(AMDGPU::VS_32_Lo128RegClassID, MVT::bf16);
+  }
+
   bool isVCSrcFake16F16_Lo128() const {
     return isRegOrInlineNoMods(AMDGPU::VS_32_Lo128RegClassID, MVT::f16);
   }
 
+  bool isVCSrc_bf16() const {
+    return isRegOrInlineNoMods(AMDGPU::VS_32RegClassID, MVT::bf16);
+  }
+
   bool isVCSrc_f16() const {
     return isRegOrInlineNoMods(AMDGPU::VS_32RegClassID, MVT::f16);
   }
 
+  bool isVCSrc_v2bf16() const { return isVCSrc_bf16(); }
+
   bool isVCSrc_v2f16() const { return isVCSrc_f16(); }
 
   bool isVSrc_b32() const {
@@ -597,18 +617,34 @@ class AMDGPUOperand : public MCParsedAsmOperand {
 
   bool isVSrc_f64() const { return isVCSrcF64() || isLiteralImm(MVT::f64); }
 
+  bool isVSrcT_bf16() const { return isVCSrcTBF16() || isLiteralImm(MVT::bf16); }
+
   bool isVSrcT_f16() const { return isVCSrcTF16() || isLiteralImm(MVT::f16); }
 
+  bool isVSrcT_bf16_Lo128() const {
+    return isVCSrcTBF16_Lo128() || isLiteralImm(MVT::bf16);
+  }
+
   bool isVSrcT_f16_Lo128() const {
     return isVCSrcTF16_Lo128() || isLiteralImm(MVT::f16);
   }
 
+  bool isVSrcFake16_bf16_Lo128() const {
+    return isVCSrcFake16BF16_Lo128() || isLiteralImm(MVT::bf16);
+  }
+
   bool isVSrcFake16_f16_Lo128() const {
     return isVCSrcFake16F16_Lo128() || isLiteralImm(MVT::f16);
   }
 
+  bool isVSrc_bf16() const { return isVCSrc_bf16() || isLiteralImm(MVT::bf16); }
+
   bool isVSrc_f16() const { return isVCSrc_f16() || isLiteralImm(MVT::f16); }
 
+  bool isVSrc_v2bf16() const {
+    return isVSrc_bf16() || isLiteralImm(MVT::v2bf16);
+  }
+
   bool isVSrc_v2f16() const { return isVSrc_f16() || isLiteralImm(MVT::v2f16); }
 
   bool isVISrcB32() const {
@@ -635,6 +671,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     return isVISrcF16() || isVISrcB32();
   }
 
+  bool isVISrc_64_bf16() const {
+    return isRegOrInlineNoMods(AMDGPU::VReg_64RegClassID, MVT::bf16);
+  }
+
   bool isVISrc_64_f16() const {
     return isRegOrInlineNoMods(AMDGPU::VReg_64RegClassID, MVT::f16);
   }
@@ -803,6 +843,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     return isAISrc_128F16() || isAISrc_128_b32();
   }
 
+  bool isVISrc_128_bf16() const {
+    return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::bf16);
+  }
+
   bool isVISrc_128_f16() const {
     return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::f16);
   }
@@ -1890,6 +1934,14 @@ static const fltSemantics *getOpFltSemantics(uint8_t OperandType) {
   case AMDGPU::OPERAND_REG_IMM_V2FP16:
   case AMDGPU::OPERAND_KIMM16:
     return &APFloat::IEEEhalf();
+  case AMDGPU::OPERAND_REG_IMM_BF16:
+  case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
+  case AMDGPU::OPERAND_REG_INLINE_C_BF16:
+  case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
+  case AMDGPU::OPERAND_REG_IMM_V2BF16:
+    return &APFloat::BFloat();
   default:
     llvm_unreachable("unsupported fp type");
   }
@@ -2186,17 +2238,24 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
     case AMDGPU::OPERAND_REG_INLINE_AC_INT32:
     case AMDGPU::OPERAND_REG_INLINE_AC_FP32:
     case AMDGPU::OPERAND_REG_IMM_INT16:
+    case AMDGPU::OPERAND_REG_IMM_BF16:
     case AMDGPU::OPERAND_REG_IMM_FP16:
+    case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
     case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
     case AMDGPU::OPERAND_REG_INLINE_C_INT16:
+    case AMDGPU::OPERAND_REG_INLINE_C_BF16:
     case AMDGPU::OPERAND_REG_INLINE_C_FP16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
+    case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
     case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
+    case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
     case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
     case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
+    case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
     case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
     case AMDGPU::OPERAND_REG_IMM_V2INT16:
+    case AMDGPU::OPERAND_REG_IMM_V2BF16:
     case AMDGPU::OPERAND_REG_IMM_V2FP16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2FP32:
     case AMDGPU::OPERAND_REG_IMM_V2FP32:
@@ -2240,6 +2299,7 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
   case AMDGPU::OPERAND_REG_INLINE_AC_INT32:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP32:
   case AMDGPU::OPERAND_REG_IMM_V2INT16:
+  case AMDGPU::OPERAND_REG_IMM_V2BF16:
   case AMDGPU::OPERAND_REG_IMM_V2FP16:
   case AMDGPU::OPERAND_REG_IMM_V2FP32:
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP32:
@@ -2277,11 +2337,15 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
     return;
 
   case AMDGPU::OPERAND_REG_IMM_INT16:
+  case AMDGPU::OPERAND_REG_IMM_BF16:
   case AMDGPU::OPERAND_REG_IMM_FP16:
+  case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
   case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
   case AMDGPU::OPERAND_REG_INLINE_C_INT16:
+  case AMDGPU::OPERAND_REG_INLINE_C_BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_FP16:
   case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
     if (isSafeTruncation(Val, 16) &&
         AMDGPU::isInlinableLiteral16(static_cast<int16_t>(Val),
@@ -2296,8 +2360,10 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
     return;
 
   case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
+  case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16: {
     assert(isSafeTruncation(Val, 16));
     assert(AMDGPU::isInlinableLiteral16(static_cast<int16_t>(Val),
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
index abfa4a3531e8e1..96a0168f37e405 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
@@ -521,8 +521,11 @@ void AMDGPUInstPrinter::printImmediateV216(uint32_t Imm, uint8_t OpType,
     if (printImmediateFloat32(Imm, STI, O))
       return;
     break;
+  case AMDGPU::OPERAND_REG_IMM_V2BF16:
   case AMDGPU::OPERAND_REG_IMM_V2FP16:
+  case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
     if (isUInt<16>(Imm) &&
         printImmediateFloat16(static_cast<uint16_t>(Imm), STI, O))
@@ -792,17 +795,24 @@ void AMDGPUInstPrinter::printRegularOperand(const MCInst *MI, unsigned OpNo,
     case AMDGPU::OPERAND_REG_IMM_INT16:
       printImmediateInt16(Op.getImm(), STI, O);
       break;
+    case AMDGPU::OPERAND_REG_INLINE_C_BF16:
     case AMDGPU::OPERAND_REG_INLINE_C_FP16:
+    case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
     case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
+    case AMDGPU::OPERAND_REG_IMM_BF16:
     case AMDGPU::OPERAND_REG_IMM_FP16:
+    case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
     case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
       printImmediate16(Op.getImm(), STI, O);
       break;
     case AMDGPU::OPERAND_REG_IMM_V2INT16:
+    case AMDGPU::OPERAND_REG_IMM_V2BF16:
     case AMDGPU::OPERAND_REG_IMM_V2FP16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
     case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
+    case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
+    case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
     case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
       printImmediateV216(Op.getImm(), OpTy, STI, O);
       break;
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp
index 11f5e456e8d348..9ec174ba56c242 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp
@@ -276,9 +276,13 @@ AMDGPUMCCodeEmitter::getLitEncoding(const MCOperand &MO,
   case AMDGPU::OPERAND_REG_INLINE_C_INT16:
   case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
     return getLit16IntEncoding(static_cast<uint16_t>(Imm), STI);
+  case AMDGPU::OPERAND_REG_IMM_BF16:
   case AMDGPU::OPERAND_REG_IMM_FP16:
+  case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
   case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
+  case AMDGPU::OPERAND_REG_INLINE_C_BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_FP16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
     // FIXME Is this correct? What do inline immediates do on SI for f16 src
     // which does not have f16 support?
@@ -288,8 +292,11 @@ AMDGPUMCCodeEmitter::getLitEncoding(const MCOperand &MO,
   case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
     return AMDGPU::getInlineEncodingV2I16(static_cast<uint32_t>(Imm))
         .value_or(255);
+  case AMDGPU::OPERAND_REG_IMM_V2BF16:
   case AMDGPU::OPERAND_REG_IMM_V2FP16:
+  case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
     return AMDGPU::getInlineEncodingV2F16(static_cast<uint32_t>(Imm))
         .value_or(255);
diff --git a/llvm/lib/Target/AMDGPU/SIDefines.h b/llvm/lib/Target/AMDGPU/SIDefines.h
index 19596d53b45328..66b997eb180613 100644
--- a/llvm/lib/Target/AMDGPU/SIDefines.h
+++ b/llvm/lib/Target/AMDGPU/SIDefines.h
@@ -196,9 +196,12 @@ enum OperandType : unsigned {
   OPERAND_REG_IMM_INT16,
   OPERAND_REG_IMM_FP32,
   OPERAND_REG_IMM_FP64,
+  OPERAND_REG_IMM_BF16,
   OPERAND_REG_IMM_FP16,
+  OPERAND_REG_IMM_BF16_DEFERRED,
   OPERAND_REG_IMM_FP16_DEFERRED,
   OPERAND_REG_IMM_FP32_DEFERRED,
+  OPERAND_REG_IMM_V2BF16,
   OPERAND_REG_IMM_V2FP16,
   OPERAND_REG_IMM_V2INT16,
   OPERAND_REG_IMM_V2INT32,
@@ -208,10 +211,12 @@ enum OperandType : unsigned {
   OPERAND_REG_INLINE_C_INT16,
   OPERAND_REG_INLINE_C_INT32,
   OPERAND_REG_INLINE_C_INT64,
+  OPERAND_REG_INLINE_C_BF16,
   OPERAND_REG_INLINE_C_FP16,
   OPERAND_REG_INLINE_C_FP32,
   OPERAND_REG_INLINE_C_FP64,
   OPERAND_REG_INLINE_C_V2INT16,
+  OPERAND_REG_INLINE_C_V2BF16,
   OPERAND_REG_INLINE_C_V2FP16,
   OPERAND_REG_INLINE_C_V2INT32,
   OPERAND_REG_INLINE_C_V2FP32,
@@ -226,10 +231,12 @@ enum OperandType : unsigned {
   /// Operands with an AccVGPR register or inline constant
   OPERAND_REG_INLINE_AC_INT16,
   OPERAND_REG_INLINE_AC_INT32,
+  OPERAND_REG_INLINE_AC_BF16,
   OPERAND_REG_INLINE_AC_FP16,
   OPERAND_REG_INLINE_AC_FP32,
   OPERAND_REG_INLINE_AC_FP64,
   OPERAND_REG_INLINE_AC_V2INT16,
+  OPERAND_REG_INLINE_AC_V2BF16,
   OPERAND_REG_INLINE_AC_V2FP16,
   OPERAND_REG_INLINE_AC_V2INT32,
   OPERAND_REG_INLINE_AC_V2FP32,
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp
index c7628bd354309c..fcb2a6f1f3d75d 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp
@@ -4181,13 +4181,20 @@ bool SIInstrInfo::isInlineConstant(const MachineOperand &MO,
   case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
     return AMDGPU::isInlinableLiteralV2I16(Imm);
+  case AMDGPU::OPERAND_REG_IMM_V2BF16:
   case AMDGPU::OPERAND_REG_IMM_V2FP16:
+  case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
     return AMDGPU::isInlinableLiteralV2F16(Imm);
+  case AMDGPU::OPERAND_REG_IMM_BF16:
   case AMDGPU::OPERAND_REG_IMM_FP16:
+  case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
   case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
+  case AMDGPU::OPERAND_REG_INLINE_C_BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_FP16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP16: {
     if (isInt<16>(Imm) || isUInt<16>(Imm)) {
       // A few special case instructions have 16-bit operands on subtargets
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
index 22599773d562cb..b0daec4a350eb3 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
@@ -1497,20 +1497,17 @@ class getVOPSrc0ForVT<ValueType VT, bit IsTrue16, bit IsFake16 = 1> {
   RegisterOperand ret =
     !if(VT.isFP,
       !if(!eq(VT.Size, 64),
-         VSrc_f64,
-         !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
-            !if(IsTrue16,
-              !if(IsFake16, VSrcFake16_f16_Lo128, VSrcT_f16_Lo128),
-              VSrc_f16
-            ),
-            !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
-               VSrc_v2f16,
-               !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
-                 AVSrc_64,
-                 VSrc_f32
+          VSrc_f64,
+          !if(!eq(VT.Value, f16.Value),
+              !if(IsTrue16, !if(IsFake16, VSrcFake16_f16_Lo128, VSrcT_f16_Lo128), VSrc_f16),
+              !if(!eq(VT.Value, bf16.Value),
+                 !if(IsTrue16, !if(IsFake16, VSrcFake16_bf16_Lo128, VSrcT_bf16_Lo128), VSrc_bf16),
+                 !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
+                     !if(!eq(VT.Value, v2f16.Value), VSrc_v2f16, VSrc_v2bf16),
+                     !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)), AVSrc_64, VSrc_f32)
+                  )
                )
-            )
-         )
+           )
        ),
        !if(!eq(VT.Size, 64),
           VSrc_b64,
@@ -1569,16 +1566,20 @@ class getVOP3SrcForVT<ValueType VT, bit IsTrue16 = 0> {
         !if(!eq(VT.Value, i1.Value),
            SSrc_i1,
            !if(VT.isFP,
-              !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
-                 !if(IsTrue16, VSrcT_f16, VSrc_f16),
-                 !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
-                    VSrc_v2f16,
-                    !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
-                      AVSrc_64,
-                      VSrc_f32
-                    )
-                 )
-              ),
+               !if(!eq(VT.Value, f16.Value),
+                   !if(IsTrue16, VSrcT_f16, VSrc_f16),
+                   !if(!eq(VT.Value, bf16.Value),
+                       !if(IsTrue16, VSrcT_bf16, VSrc_bf16),
+                       !if(!eq(VT.Value, v2f16.Value),
+                           VSrc_v2f16,
+                           !if(!eq(VT.Value, v2bf16.Value),
+                               VSrc_v2bf16,
+                               !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
+                                   AVSrc_64, VSrc_f32)
+                           )
+                       )
+                   )
+               ),
               !if(!eq(VT.Value, i16.Value),
                  !if(IsTrue16, VSrcT_b16, VSrc_b16),
                  !if(!eq(VT.Value, v2i16.Value),
@@ -1597,8 +1598,13 @@ class getVOP3DPPSrcForVT<ValueType VT> {
   RegisterOperand ret =
       !if (!eq(VT.Value, i1.Value), SSrc_i1,
            !if (VT.isFP,
-                !if (!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)), VCSrc_f16,
-                     !if (!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)), VCSrc_v2f16, VCSrc_f32)),
+                !if(!eq(VT.Value, f16.Value), VCSrc_f16,
+                    !if(!eq(VT.Value, bf16.Value), VCSrc_bf16,
+                        !if(!eq(VT.Value, v2f16.Value), VCSrc_v2f16,
+                            !if(!eq(VT.Value, v2bf16.Value), VCSrc_v2bf16, VCSrc_f32)
+                        )
+                    )
+                ),
                 !if (!eq(VT.Value, i16.Value), VCSrc_b16,
                      !if (!eq(VT.Value, v2i16.Value), VCSrc_v2b16,
                           VCSrc_b32))));
@@ -2528,7 +2534,7 @@ def VOP_V2I16_F32_F32 : VOPProfile <[v2i16, f32, f32, untyped]>;
 def VOP_V2I16_I32_I32 : VOPProfile <[v2i16, i32, i32, untyped]>;
 
 def VOP_F16_V2F16_V2F16_F16 : VOPProfile <[f16, v2f16, v2f16, f16]>;
-def VOP_I16_V2I16_V2I16_I16 : VOPProfile <[i16, v2i16, v2i16, i16]>;
+def VOP_BF16_V2BF16_V2BF16_BF16: VOPProfile <[bf16, v2bf16, v2bf16, bf16]>;
 def VOP_F32_V2I16_V2I16_F32 : VOPProfile <[f32, v2i16, v2i16, f32]>;
 
 def VOP_F32_V2F16_V2F16_V2F16 : VOPProfile <[f32, v2f16, v2f16, v2f16]>;
diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
index aabb6c29062114..f24e65304d2052 100644
--- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
@@ -1...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/80908


More information about the cfe-commits mailing list