[llvm] [RFC][WIP][AMDGPU] Use `bf16` instead of `i16` for bfloat (PR #80908)

Shilei Tian via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 6 14:03:33 PST 2024


https://github.com/shiltian created https://github.com/llvm/llvm-project/pull/80908

Currently it looks like we generally use `i16` to represent `bf16` in the AMDGPU
tablegen files. I'm not sure of the reason behind it. My wild guess is that the
`bf16` type was not available when we enabled the support. This patch tries to
use `bf16` directly in those tablegen files, aiming to fix #79369. Of course, a
workaround for #79369 would be to treat all `INT16` variants as `BFloat` in
`getOpFltSemantics`, but that doesn't look good IMHO.

Since I'm fairly new to the AMDGPU backend, I'd appreciate it if you could point
out anything I have misunderstood.


>From f8de3422157be641ce88aed8204f85a8ee2a070e Mon Sep 17 00:00:00 2001
From: Shilei Tian <i at tianshilei.me>
Date: Tue, 6 Feb 2024 16:55:53 -0500
Subject: [PATCH] [RFC][WIP][AMDGPU] Use `bf16` instead of `i16` for bfloat

Currently it looks like we generally use `i16` to represent `bf16` in the AMDGPU
tablegen files. I'm not sure of the reason behind it. My wild guess is that the
`bf16` type was not available when we enabled the support. This patch tries to
use `bf16` directly in those tablegen files, aiming to fix #79369. Of course, a
workaround for #79369 would be to treat all `INT16` variants as `BFloat` in
`getOpFltSemantics`, but that doesn't look good IMHO.

Since I'm fairly new to the AMDGPU backend, I'd appreciate it if you could point
out anything I have misunderstood.
---
 llvm/include/llvm/IR/IntrinsicsAMDGPU.td      | 12 ++---
 .../AMDGPU/AsmParser/AMDGPUAsmParser.cpp      | 44 +++++++++++++++++++
 llvm/lib/Target/AMDGPU/SIDefines.h            |  7 +++
 llvm/lib/Target/AMDGPU/SIInstrInfo.td         | 27 +++++-------
 llvm/lib/Target/AMDGPU/SIRegisterInfo.td      | 22 +++++++++-
 llvm/lib/Target/AMDGPU/VOP3Instructions.td    |  2 +-
 llvm/lib/Target/AMDGPU/VOP3PInstructions.td   |  2 +-
 7 files changed, 92 insertions(+), 24 deletions(-)

diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index 202fa4e8f4ea81..0f29653f1f5bec 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -2819,11 +2819,11 @@ def int_amdgcn_fdot2_f16_f16 :
 def int_amdgcn_fdot2_bf16_bf16 :
   ClangBuiltin<"__builtin_amdgcn_fdot2_bf16_bf16">,
   DefaultAttrsIntrinsic<
-    [llvm_i16_ty],   // %r
+    [llvm_bfloat_ty],   // %r
     [
-      llvm_v2i16_ty, // %a
-      llvm_v2i16_ty, // %b
-      llvm_i16_ty    // %c
+      llvm_v2bf16_ty, // %a
+      llvm_v2bf16_ty, // %b
+      llvm_bfloat_ty    // %c
     ],
     [IntrNoMem, IntrSpeculatable]
   >;
@@ -2835,8 +2835,8 @@ def int_amdgcn_fdot2_f32_bf16 :
   DefaultAttrsIntrinsic<
     [llvm_float_ty], // %r
     [
-      llvm_v2i16_ty, // %a
-      llvm_v2i16_ty, // %b
+      llvm_v2bf16_ty, // %a
+      llvm_v2bf16_ty, // %b
       llvm_float_ty, // %c
       llvm_i1_ty     // %clamp
     ],
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index 225e781588668f..f2adcf38b2ed7f 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -474,6 +474,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
 
   bool isSSrcF64() const { return isSCSrc_b64() || isLiteralImm(MVT::f64); }
 
+  bool isSSrc_bf16() const { return isSCSrcB16() || isLiteralImm(MVT::bf16); }
+
   bool isSSrc_f16() const { return isSCSrcB16() || isLiteralImm(MVT::f16); }
 
   bool isSSrcV2F16() const {
@@ -540,22 +542,40 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     return isRegOrInlineNoMods(AMDGPU::VS_64RegClassID, MVT::f64);
   }
 
+  bool isVCSrcTBF16() const {
+    return isRegOrInlineNoMods(AMDGPU::VS_16RegClassID, MVT::bf16);
+  }
+
   bool isVCSrcTF16() const {
     return isRegOrInlineNoMods(AMDGPU::VS_16RegClassID, MVT::f16);
   }
 
+  bool isVCSrcTBF16_Lo128() const {
+    return isRegOrInlineNoMods(AMDGPU::VS_16_Lo128RegClassID, MVT::bf16);
+  }
+
   bool isVCSrcTF16_Lo128() const {
     return isRegOrInlineNoMods(AMDGPU::VS_16_Lo128RegClassID, MVT::f16);
   }
 
+  bool isVCSrcFake16BF16_Lo128() const {
+    return isRegOrInlineNoMods(AMDGPU::VS_32_Lo128RegClassID, MVT::bf16);
+  }
+
   bool isVCSrcFake16F16_Lo128() const {
     return isRegOrInlineNoMods(AMDGPU::VS_32_Lo128RegClassID, MVT::f16);
   }
 
+  bool isVCSrc_bf16() const {
+    return isRegOrInlineNoMods(AMDGPU::VS_32RegClassID, MVT::bf16);
+  }
+
   bool isVCSrc_f16() const {
     return isRegOrInlineNoMods(AMDGPU::VS_32RegClassID, MVT::f16);
   }
 
+  bool isVCSrc_v2bf16() const { return isVCSrc_bf16(); }
+
   bool isVCSrc_v2f16() const { return isVCSrc_f16(); }
 
   bool isVSrc_b32() const {
@@ -596,18 +616,34 @@ class AMDGPUOperand : public MCParsedAsmOperand {
 
   bool isVSrc_f64() const { return isVCSrcF64() || isLiteralImm(MVT::f64); }
 
+  bool isVSrcT_bf16() const { return isVCSrcTBF16() || isLiteralImm(MVT::bf16); }
+
   bool isVSrcT_f16() const { return isVCSrcTF16() || isLiteralImm(MVT::f16); }
 
+  bool isVSrcT_bf16_Lo128() const {
+    return isVCSrcTBF16_Lo128() || isLiteralImm(MVT::bf16);
+  }
+
   bool isVSrcT_f16_Lo128() const {
     return isVCSrcTF16_Lo128() || isLiteralImm(MVT::f16);
   }
 
+  bool isVSrcFake16_bf16_Lo128() const {
+    return isVCSrcFake16BF16_Lo128() || isLiteralImm(MVT::bf16);
+  }
+
   bool isVSrcFake16_f16_Lo128() const {
     return isVCSrcFake16F16_Lo128() || isLiteralImm(MVT::f16);
   }
 
+  bool isVSrc_bf16() const { return isVCSrc_bf16() || isLiteralImm(MVT::bf16); }
+
   bool isVSrc_f16() const { return isVCSrc_f16() || isLiteralImm(MVT::f16); }
 
+  bool isVSrc_v2bf16() const {
+    return isVSrc_bf16() || isLiteralImm(MVT::v2bf16);
+  }
+
   bool isVSrc_v2f16() const { return isVSrc_f16() || isLiteralImm(MVT::v2f16); }
 
   bool isVISrcB32() const {
@@ -634,6 +670,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     return isVISrcF16() || isVISrcB32();
   }
 
+  bool isVISrc_64_bf16() const {
+    return isRegOrInlineNoMods(AMDGPU::VReg_64RegClassID, MVT::bf16);
+  }
+
   bool isVISrc_64_f16() const {
     return isRegOrInlineNoMods(AMDGPU::VReg_64RegClassID, MVT::f16);
   }
@@ -802,6 +842,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     return isAISrc_128F16() || isAISrc_128_b32();
   }
 
+  bool isVISrc_128_bf16() const {
+    return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::bf16);
+  }
+
   bool isVISrc_128_f16() const {
     return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::f16);
   }
diff --git a/llvm/lib/Target/AMDGPU/SIDefines.h b/llvm/lib/Target/AMDGPU/SIDefines.h
index 19596d53b45328..66b997eb180613 100644
--- a/llvm/lib/Target/AMDGPU/SIDefines.h
+++ b/llvm/lib/Target/AMDGPU/SIDefines.h
@@ -196,9 +196,12 @@ enum OperandType : unsigned {
   OPERAND_REG_IMM_INT16,
   OPERAND_REG_IMM_FP32,
   OPERAND_REG_IMM_FP64,
+  OPERAND_REG_IMM_BF16,
   OPERAND_REG_IMM_FP16,
+  OPERAND_REG_IMM_BF16_DEFERRED,
   OPERAND_REG_IMM_FP16_DEFERRED,
   OPERAND_REG_IMM_FP32_DEFERRED,
+  OPERAND_REG_IMM_V2BF16,
   OPERAND_REG_IMM_V2FP16,
   OPERAND_REG_IMM_V2INT16,
   OPERAND_REG_IMM_V2INT32,
@@ -208,10 +211,12 @@ enum OperandType : unsigned {
   OPERAND_REG_INLINE_C_INT16,
   OPERAND_REG_INLINE_C_INT32,
   OPERAND_REG_INLINE_C_INT64,
+  OPERAND_REG_INLINE_C_BF16,
   OPERAND_REG_INLINE_C_FP16,
   OPERAND_REG_INLINE_C_FP32,
   OPERAND_REG_INLINE_C_FP64,
   OPERAND_REG_INLINE_C_V2INT16,
+  OPERAND_REG_INLINE_C_V2BF16,
   OPERAND_REG_INLINE_C_V2FP16,
   OPERAND_REG_INLINE_C_V2INT32,
   OPERAND_REG_INLINE_C_V2FP32,
@@ -226,10 +231,12 @@ enum OperandType : unsigned {
   /// Operands with an AccVGPR register or inline constant
   OPERAND_REG_INLINE_AC_INT16,
   OPERAND_REG_INLINE_AC_INT32,
+  OPERAND_REG_INLINE_AC_BF16,
   OPERAND_REG_INLINE_AC_FP16,
   OPERAND_REG_INLINE_AC_FP32,
   OPERAND_REG_INLINE_AC_FP64,
   OPERAND_REG_INLINE_AC_V2INT16,
+  OPERAND_REG_INLINE_AC_V2BF16,
   OPERAND_REG_INLINE_AC_V2FP16,
   OPERAND_REG_INLINE_AC_V2INT32,
   OPERAND_REG_INLINE_AC_V2FP32,
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
index 7edec5a7a5505b..1ef8159c98ea35 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
@@ -1490,20 +1490,17 @@ class getVOPSrc0ForVT<ValueType VT, bit IsTrue16, bit IsFake16 = 1> {
   RegisterOperand ret =
     !if(VT.isFP,
       !if(!eq(VT.Size, 64),
-         VSrc_f64,
-         !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
-            !if(IsTrue16,
-              !if(IsFake16, VSrcFake16_f16_Lo128, VSrcT_f16_Lo128),
-              VSrc_f16
-            ),
-            !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
-               VSrc_v2f16,
-               !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
-                 AVSrc_64,
-                 VSrc_f32
+          VSrc_f64,
+          !if(!eq(VT.Value, f16.Value),
+              !if(IsTrue16, !if(IsFake16, VSrcFake16_f16_Lo128, VSrcT_f16_Lo128), VSrc_f16),
+              !if(!eq(VT.Value, bf16.Value),
+                 !if(IsTrue16, !if(IsFake16, VSrcFake16_bf16_Lo128, VSrcT_bf16_Lo128), VSrc_bf16),
+                 !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
+                     !if(!eq(VT.Value, v2f16.Value), VSrc_v2f16, VSrc_v2bf16),
+                     !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)), AVSrc_64, VSrc_f32)
+                  )
                )
-            )
-         )
+           )
        ),
        !if(!eq(VT.Size, 64),
           VSrc_b64,
@@ -2513,8 +2510,8 @@ def VOP_V2I16_F32_F32 : VOPProfile <[v2i16, f32, f32, untyped]>;
 def VOP_V2I16_I32_I32 : VOPProfile <[v2i16, i32, i32, untyped]>;
 
 def VOP_F16_V2F16_V2F16_F16 : VOPProfile <[f16, v2f16, v2f16, f16]>;
-def VOP_I16_V2I16_V2I16_I16 : VOPProfile <[i16, v2i16, v2i16, i16]>;
-def VOP_F32_V2I16_V2I16_F32 : VOPProfile <[f32, v2i16, v2i16, f32]>;
+def VOP_BF16_V2BF16_V2BF16_BF16: VOPProfile <[bf16, v2bf16, v2bf16, bf16]>;
+def VOP_F32_V2BF16_V2BF16_F32 : VOPProfile <[f32, v2bf16, v2bf16, f32]>;
 
 def VOP_F32_V2F16_V2F16_V2F16 : VOPProfile <[f32, v2f16, v2f16, v2f16]>;
 
diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
index c9dbe02037ef2e..5c5dd6a4e1a63f 100644
--- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
@@ -1066,7 +1066,7 @@ multiclass AVRegClass<int numRegs, list<ValueType> regTypes,
     // Define the regular class.
     def "" : VRegClassBase<numRegs, regTypes, (add vregList, aregList)>;
 
-    // Define 2-aligned variant 
+    // Define 2-aligned variant
     def _Align2 : VRegClassBase<numRegs, regTypes,
                                 (add (decimate vregList, 2),
                                      (decimate aregList, 2))> {
@@ -1115,6 +1115,7 @@ class RegOrImmOperand <string RegisterClassName, string OperandTypeName,
 //===----------------------------------------------------------------------===//
 
 def SSrc_b16 : RegOrImmOperand <"SReg_32", "OPERAND_REG_IMM_INT16", "_Imm16">;
+def SSrc_bf16: RegOrImmOperand <"SReg_32", "OPERAND_REG_IMM_BF16", "_Imm16">;
 def SSrc_f16 : RegOrImmOperand <"SReg_32", "OPERAND_REG_IMM_FP16", "_Imm16">;
 def SSrc_b32 : RegOrImmOperand <"SReg_32", "OPERAND_REG_IMM_INT32", "_Imm32">;
 def SSrc_f32 : RegOrImmOperand <"SReg_32", "OPERAND_REG_IMM_FP32", "_Imm32">;
@@ -1142,6 +1143,7 @@ def SCSrc_b64 : RegOrImmOperand <"SReg_64", "OPERAND_REG_INLINE_C_INT64", "_Imm6
 
 // The current and temporary future default used case for VOP3.
 def VSrc_b16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_INT16", "_Imm16">;
+def VSrc_bf16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_BF16", "_Imm16">;
 def VSrc_f16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_FP16", "_Imm16">;
 
 // True16 VOP3 operands.
@@ -1149,6 +1151,10 @@ def VSrcT_b16 : RegOrImmOperand <"VS_16", "OPERAND_REG_IMM_INT16", "_Imm16"> {
   let EncoderMethod = "getMachineOpValueT16";
   let DecoderMethod = "decodeOperand_VSrcT16";
 }
+def VSrcT_bf16 : RegOrImmOperand <"VS_16", "OPERAND_REG_IMM_BF16", "_Imm16"> {
+  let EncoderMethod = "getMachineOpValueT16";
+  let DecoderMethod = "decodeOperand_VSrcT16";
+}
 def VSrcT_f16 : RegOrImmOperand <"VS_16", "OPERAND_REG_IMM_FP16", "_Imm16"> {
   let EncoderMethod = "getMachineOpValueT16";
   let DecoderMethod = "decodeOperand_VSrcT16";
@@ -1159,6 +1165,10 @@ def VSrcT_b16_Lo128 : RegOrImmOperand <"VS_16_Lo128", "OPERAND_REG_IMM_INT16", "
   let EncoderMethod = "getMachineOpValueT16Lo128";
   let DecoderMethod = "decodeOperand_VSrcT16_Lo128";
 }
+def VSrcT_bf16_Lo128 : RegOrImmOperand <"VS_16_Lo128", "OPERAND_REG_IMM_BF16", "_Imm16"> {
+  let EncoderMethod = "getMachineOpValueT16Lo128";
+  let DecoderMethod = "decodeOperand_VSrcT16_Lo128";
+}
 def VSrcT_f16_Lo128 : RegOrImmOperand <"VS_16_Lo128", "OPERAND_REG_IMM_FP16", "_Imm16"> {
   let EncoderMethod = "getMachineOpValueT16Lo128";
   let DecoderMethod = "decodeOperand_VSrcT16_Lo128";
@@ -1167,11 +1177,13 @@ def VSrcT_f16_Lo128 : RegOrImmOperand <"VS_16_Lo128", "OPERAND_REG_IMM_FP16", "_
 // The current and temporary future default used case for fake VOP1/2/C.
 // For VOP1,2,C True16 instructions. _Lo128 use first 128 32-bit VGPRs only.
 def VSrcFake16_b16_Lo128 : RegOrImmOperand <"VS_32_Lo128", "OPERAND_REG_IMM_INT16", "_Imm16">;
+def VSrcFake16_bf16_Lo128 : RegOrImmOperand <"VS_32_Lo128", "OPERAND_REG_IMM_BF16", "_Imm16">;
 def VSrcFake16_f16_Lo128 : RegOrImmOperand <"VS_32_Lo128", "OPERAND_REG_IMM_FP16", "_Imm16">;
 
 def VSrc_b32 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_INT32", "_Imm32">;
 def VSrc_f32 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_FP32", "_Imm32">;
 def VSrc_v2b16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_V2INT16", "_ImmV2I16">;
+def VSrc_v2bf16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_V2BF16", "_ImmV2F16">;
 def VSrc_v2f16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_V2FP16", "_ImmV2F16">;
 def VSrc_b64 : RegOrImmOperand <"VS_64", "OPERAND_REG_IMM_INT64", "_Imm64">;
 def VSrc_f64 : RegOrImmOperand <"VS_64", "OPERAND_REG_IMM_FP64", "_Imm64"> {
@@ -1185,9 +1197,13 @@ def VSrc_v2f32 : RegOrImmOperand <"VS_64", "OPERAND_REG_IMM_V2FP32", "_Imm32">;
 //  with FMAMK/FMAAK
 //===----------------------------------------------------------------------===//
 
+def VSrc_bf16_Deferred : RegOrImmOperand<"VS_32", "OPERAND_REG_IMM_BF16_DEFERRED", "_Deferred_Imm16">;
 def VSrc_f16_Deferred : RegOrImmOperand<"VS_32", "OPERAND_REG_IMM_FP16_DEFERRED", "_Deferred_Imm16">;
 def VSrc_f32_Deferred : RegOrImmOperand<"VS_32", "OPERAND_REG_IMM_FP32_DEFERRED", "_Deferred_Imm32">;
 
+def VSrcFake16_bf16_Lo128_Deferred : RegOrImmOperand<"VS_32_Lo128",
+                                                    "OPERAND_REG_IMM_BF16_DEFERRED",
+                                                    "_Deferred_Imm16">;
 def VSrcFake16_f16_Lo128_Deferred : RegOrImmOperand<"VS_32_Lo128",
                                                     "OPERAND_REG_IMM_FP16_DEFERRED",
                                                     "_Deferred_Imm16">;
@@ -1252,19 +1268,23 @@ def ARegSrc_32 : AVOperand<AGPR_32, "decodeSrcA9", "OPW32">;
 //===----------------------------------------------------------------------===//
 
 def VCSrc_b16 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_INT16", "_Imm16">;
+def VCSrc_bf16: RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_BF16", "_Imm16">;
 def VCSrc_f16 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_FP16", "_Imm16">;
 def VCSrc_b32 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_INT32", "_Imm32">;
 def VCSrc_f32 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_FP32", "_Imm32">;
 def VCSrc_v2b16 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_V2INT16", "_ImmV2I16">;
+def VCSrc_v2bf16 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_V2BF16", "_ImmV2F16">;
 def VCSrc_v2f16 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_V2FP16", "_ImmV2F16">;
 
 //===----------------------------------------------------------------------===//
 //  VISrc_* Operands with a VGPR or an inline constant
 //===----------------------------------------------------------------------===//
 
+def VISrc_64_bf16 : RegOrImmOperand <"VReg_64", "OPERAND_REG_INLINE_C_BF16", "_Imm16">;
 def VISrc_64_f16 : RegOrImmOperand <"VReg_64", "OPERAND_REG_INLINE_C_FP16", "_Imm16">;
 def VISrc_64_b32 : RegOrImmOperand <"VReg_64", "OPERAND_REG_INLINE_C_INT32", "_Imm32">;
 def VISrc_64_f64 : RegOrImmOperand <"VReg_64", "OPERAND_REG_INLINE_C_FP64", "_Imm64">;
+def VISrc_128_bf16 : RegOrImmOperand <"VReg_128", "OPERAND_REG_INLINE_C_BF16", "_Imm16">;
 def VISrc_128_f16 : RegOrImmOperand <"VReg_128", "OPERAND_REG_INLINE_C_FP16", "_Imm16">;
 def VISrc_128_b32 : RegOrImmOperand <"VReg_128", "OPERAND_REG_INLINE_C_INT32", "_Imm32">;
 def VISrc_128_f32 : RegOrImmOperand <"VReg_128", "OPERAND_REG_INLINE_C_FP32", "_Imm32">;
diff --git a/llvm/lib/Target/AMDGPU/VOP3Instructions.td b/llvm/lib/Target/AMDGPU/VOP3Instructions.td
index 8d965d3b9041d5..35cffa22f45929 100644
--- a/llvm/lib/Target/AMDGPU/VOP3Instructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3Instructions.td
@@ -904,7 +904,7 @@ let SubtargetPredicate = isGFX12Plus, ReadsModeReg = 0 in {
 
 let SubtargetPredicate = HasDot9Insts, IsDOT=1 in {
   defm V_DOT2_F16_F16 :   VOP3Inst<"v_dot2_f16_f16",   VOP3_DOT_Profile<VOP_F16_V2F16_V2F16_F16>, int_amdgcn_fdot2_f16_f16>;
-  defm V_DOT2_BF16_BF16 : VOP3Inst<"v_dot2_bf16_bf16", VOP3_DOT_Profile<VOP_I16_V2I16_V2I16_I16>, int_amdgcn_fdot2_bf16_bf16>;
+  defm V_DOT2_BF16_BF16 : VOP3Inst<"v_dot2_bf16_bf16", VOP3_DOT_Profile<VOP_BF16_V2BF16_V2BF16_BF16>, int_amdgcn_fdot2_bf16_bf16>;
 }
 
 class VOP_Pseudo_Scalar<RegisterClass Dst, RegisterOperand SrcOp,
diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
index ef14a587c42e79..5686b2e539e31d 100644
--- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
@@ -396,7 +396,7 @@ defm V_DOT8_I32_I4  : VOP3PInst<"v_dot8_i32_i4",
 } // End OtherPredicates = [HasDot1Insts]
 
 def DOT2_BF16_Profile
-  : VOP3P_Profile<VOP_F32_V2I16_V2I16_F32, VOP3_REGULAR, /*HasDPP*/ 1> {
+  : VOP3P_Profile<VOP_F32_V2BF16_V2BF16_F32, VOP3_REGULAR, /*HasDPP*/ 1> {
   let HasSrc1Mods = 1;
 }
 



More information about the llvm-commits mailing list