[clang] [llvm] [RFC][WIP][AMDGPU] Use `bf16` instead of `i16` for bfloat (PR #80908)
Shilei Tian via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 7 20:09:43 PST 2024
https://github.com/shiltian updated https://github.com/llvm/llvm-project/pull/80908
>From 6a2bacee940d95abc53bcff2332b0d9aa0f1073f Mon Sep 17 00:00:00 2001
From: Shilei Tian <i at tianshilei.me>
Date: Wed, 7 Feb 2024 23:09:33 -0500
Subject: [PATCH] [RFC][WIP][AMDGPU] Use `bf16` instead of `i16` for bfloat
Currently it looks like we generally use `i16` to represent `bf16` in the AMDGPU
TableGen files. I'm not sure of the reason behind it; my guess is that the `bf16`
type was not available when the support was first enabled. This patch tries to use
`bf16` directly in those TableGen files, with the aim of fixing #79369. Of course,
a workaround for #79369 would be to treat all `INT16` variants as `BFloat` in
`getOpFltSemantics`, but that doesn't look good IMHO.
Since I'm fairly new to the AMDGPU backend, I'd appreciate it if you could point
out anything I have misunderstood.
---
clang/include/clang/Basic/BuiltinsAMDGPU.def | 4 +-
.../builtins-amdgcn-dl-insts-err.cl | 9 ++-
.../builtins-amdgcn-dl-insts-gfx11.cl | 13 ++--
llvm/include/llvm/IR/IntrinsicsAMDGPU.td | 12 ++--
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 5 +-
.../AMDGPU/AsmParser/AMDGPUAsmParser.cpp | 66 +++++++++++++++++++
.../AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp | 10 +++
.../MCTargetDesc/AMDGPUMCCodeEmitter.cpp | 7 ++
llvm/lib/Target/AMDGPU/SIDefines.h | 7 ++
llvm/lib/Target/AMDGPU/SIInstrInfo.cpp | 7 ++
llvm/lib/Target/AMDGPU/SIInstrInfo.td | 60 +++++++++--------
llvm/lib/Target/AMDGPU/SIRegisterInfo.td | 22 ++++++-
llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h | 7 ++
llvm/lib/Target/AMDGPU/VOP3Instructions.td | 2 +-
llvm/lib/Target/AMDGPU/VOP3PInstructions.td | 2 +-
.../AMDGPU/llvm.amdgcn.fdot2.bf16.bf16.ll | 36 +++++-----
.../AMDGPU/llvm.amdgcn.fdot2.f32.bf16.ll | 14 ++--
llvm/test/MC/AMDGPU/bf16_imm.s | 8 +++
18 files changed, 217 insertions(+), 74 deletions(-)
create mode 100644 llvm/test/MC/AMDGPU/bf16_imm.s
diff --git a/clang/include/clang/Basic/BuiltinsAMDGPU.def b/clang/include/clang/Basic/BuiltinsAMDGPU.def
index 213311b96df74f..4fe236e8aca12d 100644
--- a/clang/include/clang/Basic/BuiltinsAMDGPU.def
+++ b/clang/include/clang/Basic/BuiltinsAMDGPU.def
@@ -246,8 +246,8 @@ TARGET_BUILTIN(__builtin_amdgcn_ds_atomic_fadd_v2f16, "V2hV2h*3V2h", "t", "atomi
TARGET_BUILTIN(__builtin_amdgcn_fdot2, "fV2hV2hfIb", "nc", "dot10-insts")
TARGET_BUILTIN(__builtin_amdgcn_fdot2_f16_f16, "hV2hV2hh", "nc", "dot9-insts")
-TARGET_BUILTIN(__builtin_amdgcn_fdot2_bf16_bf16, "sV2sV2ss", "nc", "dot9-insts")
-TARGET_BUILTIN(__builtin_amdgcn_fdot2_f32_bf16, "fV2sV2sfIb", "nc", "dot9-insts")
+TARGET_BUILTIN(__builtin_amdgcn_fdot2_bf16_bf16, "yV2yV2yy", "nc", "dot9-insts")
+TARGET_BUILTIN(__builtin_amdgcn_fdot2_f32_bf16, "fV2yV2yfIb", "nc", "dot9-insts")
TARGET_BUILTIN(__builtin_amdgcn_sdot2, "SiV2SsV2SsSiIb", "nc", "dot2-insts")
TARGET_BUILTIN(__builtin_amdgcn_udot2, "UiV2UsV2UsUiIb", "nc", "dot2-insts")
TARGET_BUILTIN(__builtin_amdgcn_sdot4, "SiSiSiSiIb", "nc", "dot1-insts")
diff --git a/clang/test/CodeGenOpenCL/builtins-amdgcn-dl-insts-err.cl b/clang/test/CodeGenOpenCL/builtins-amdgcn-dl-insts-err.cl
index f5317683d0ff97..fa225c4962c90b 100644
--- a/clang/test/CodeGenOpenCL/builtins-amdgcn-dl-insts-err.cl
+++ b/clang/test/CodeGenOpenCL/builtins-amdgcn-dl-insts-err.cl
@@ -5,6 +5,8 @@
typedef unsigned int uint;
typedef half __attribute__((ext_vector_type(2))) half2;
+typedef __bf16 bfloat;
+typedef bfloat __attribute__((ext_vector_type(2))) bfloat2;
typedef short __attribute__((ext_vector_type(2))) short2;
typedef unsigned short __attribute__((ext_vector_type(2))) ushort2;
@@ -15,16 +17,17 @@ kernel void builtins_amdgcn_dl_insts_err(
half2 v2hA, half2 v2hB, float fC, half hC,
short2 v2ssA, short2 v2ssB, short sC, int siA, int siB, int siC,
ushort2 v2usA, ushort2 v2usB, uint uiA, uint uiB, uint uiC,
+ bfloat2 v2bfsA, bfloat2 v2bfsB, bfloat bfC,
int A, int B, int C) {
fOut[0] = __builtin_amdgcn_fdot2(v2hA, v2hB, fC, false); // expected-error {{'__builtin_amdgcn_fdot2' needs target feature dot10-insts}}
fOut[1] = __builtin_amdgcn_fdot2(v2hA, v2hB, fC, true); // expected-error {{'__builtin_amdgcn_fdot2' needs target feature dot10-insts}}
hOut[0] = __builtin_amdgcn_fdot2_f16_f16(v2hA, v2hB, hC); // expected-error {{'__builtin_amdgcn_fdot2_f16_f16' needs target feature dot9-insts}}
- sOut[0] = __builtin_amdgcn_fdot2_bf16_bf16(v2ssA, v2ssB, sC); // expected-error {{'__builtin_amdgcn_fdot2_bf16_bf16' needs target feature dot9-insts}}
+ sOut[0] = __builtin_amdgcn_fdot2_bf16_bf16(v2bfsA, v2bfsB, bfC); // expected-error {{'__builtin_amdgcn_fdot2_bf16_bf16' needs target feature dot9-insts}}
- fOut[3] = __builtin_amdgcn_fdot2_f32_bf16(v2ssA, v2ssB, fC, false); // expected-error {{'__builtin_amdgcn_fdot2_f32_bf16' needs target feature dot9-insts}}
- fOut[4] = __builtin_amdgcn_fdot2_f32_bf16(v2ssA, v2ssB, fC, true); // expected-error {{'__builtin_amdgcn_fdot2_f32_bf16' needs target feature dot9-insts}}
+ fOut[3] = __builtin_amdgcn_fdot2_f32_bf16(v2bfsA, v2bfsB, fC, false); // expected-error {{'__builtin_amdgcn_fdot2_f32_bf16' needs target feature dot9-insts}}
+ fOut[4] = __builtin_amdgcn_fdot2_f32_bf16(v2bfsA, v2bfsB, fC, true); // expected-error {{'__builtin_amdgcn_fdot2_f32_bf16' needs target feature dot9-insts}}
siOut[0] = __builtin_amdgcn_sdot2(v2ssA, v2ssB, siC, false); // expected-error {{'__builtin_amdgcn_sdot2' needs target feature dot2-insts}}
siOut[1] = __builtin_amdgcn_sdot2(v2ssA, v2ssB, siC, true); // expected-error {{'__builtin_amdgcn_sdot2' needs target feature dot2-insts}}
diff --git a/clang/test/CodeGenOpenCL/builtins-amdgcn-dl-insts-gfx11.cl b/clang/test/CodeGenOpenCL/builtins-amdgcn-dl-insts-gfx11.cl
index dc7069decaaa61..cfd96f5ac768b7 100644
--- a/clang/test/CodeGenOpenCL/builtins-amdgcn-dl-insts-gfx11.cl
+++ b/clang/test/CodeGenOpenCL/builtins-amdgcn-dl-insts-gfx11.cl
@@ -4,16 +4,17 @@
typedef unsigned int uint;
typedef half __attribute__((ext_vector_type(2))) half2;
-typedef short __attribute__((ext_vector_type(2))) short2;
+typedef __bf16 bfloat;
+typedef bfloat __attribute__((ext_vector_type(2))) bfloat2;
typedef unsigned short __attribute__((ext_vector_type(2))) ushort2;
// CHECK-LABEL: @builtins_amdgcn_dl_insts
// CHECK: call float @llvm.amdgcn.fdot2(<2 x half> %v2hA, <2 x half> %v2hB, float %fC, i1 false)
// CHECK: call float @llvm.amdgcn.fdot2(<2 x half> %v2hA, <2 x half> %v2hB, float %fC, i1 true)
// CHECK: call half @llvm.amdgcn.fdot2.f16.f16(<2 x half> %v2hA, <2 x half> %v2hB, half %hC)
-// CHECK: call i16 @llvm.amdgcn.fdot2.bf16.bf16(<2 x i16> %v2ssA, <2 x i16> %v2ssB, i16 %sC)
-// CHECK: call float @llvm.amdgcn.fdot2.f32.bf16(<2 x i16> %v2ssA, <2 x i16> %v2ssB, float %fC, i1 false)
-// CHECK: call float @llvm.amdgcn.fdot2.f32.bf16(<2 x i16> %v2ssA, <2 x i16> %v2ssB, float %fC, i1 true)
+// CHECK: call bfloat @llvm.amdgcn.fdot2.bf16.bf16(<2 x bfloat> %v2ssA, <2 x bfloat> %v2ssB, bfloat %sC)
+// CHECK: call float @llvm.amdgcn.fdot2.f32.bf16(<2 x bfloat> %v2ssA, <2 x bfloat> %v2ssB, float %fC, i1 false)
+// CHECK: call float @llvm.amdgcn.fdot2.f32.bf16(<2 x bfloat> %v2ssA, <2 x bfloat> %v2ssB, float %fC, i1 true)
// CHECK: call i32 @llvm.amdgcn.udot4(i32 %uiA, i32 %uiB, i32 %uiC, i1 false)
// CHECK: call i32 @llvm.amdgcn.udot4(i32 %uiA, i32 %uiB, i32 %uiC, i1 true)
// CHECK: call i32 @llvm.amdgcn.sudot4(i1 true, i32 %A, i1 false, i32 %B, i32 %C, i1 false)
@@ -25,9 +26,9 @@ typedef unsigned short __attribute__((ext_vector_type(2))) ushort2;
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
kernel void builtins_amdgcn_dl_insts_err(
global float *fOut, global int *siOut, global uint *uiOut,
- global short *sOut, global int *iOut, global half *hOut,
+ global bfloat *sOut, global int *iOut, global half *hOut,
half2 v2hA, half2 v2hB, float fC, half hC,
- short2 v2ssA, short2 v2ssB, short sC, int siA, int siB, int siC,
+ bfloat2 v2ssA, bfloat2 v2ssB, bfloat sC, int siA, int siB, int siC,
ushort2 v2usA, ushort2 v2usB, uint uiA, uint uiB, uint uiC,
int A, int B, int C) {
fOut[0] = __builtin_amdgcn_fdot2(v2hA, v2hB, fC, false);
diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index 202fa4e8f4ea81..0f29653f1f5bec 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -2819,11 +2819,11 @@ def int_amdgcn_fdot2_f16_f16 :
def int_amdgcn_fdot2_bf16_bf16 :
ClangBuiltin<"__builtin_amdgcn_fdot2_bf16_bf16">,
DefaultAttrsIntrinsic<
- [llvm_i16_ty], // %r
+ [llvm_bfloat_ty], // %r
[
- llvm_v2i16_ty, // %a
- llvm_v2i16_ty, // %b
- llvm_i16_ty // %c
+ llvm_v2bf16_ty, // %a
+ llvm_v2bf16_ty, // %b
+ llvm_bfloat_ty // %c
],
[IntrNoMem, IntrSpeculatable]
>;
@@ -2835,8 +2835,8 @@ def int_amdgcn_fdot2_f32_bf16 :
DefaultAttrsIntrinsic<
[llvm_float_ty], // %r
[
- llvm_v2i16_ty, // %a
- llvm_v2i16_ty, // %b
+ llvm_v2bf16_ty, // %a
+ llvm_v2bf16_ty, // %b
llvm_float_ty, // %c
llvm_i1_ty // %clamp
],
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index dd38317c26bff6..a1c638d931b7f8 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -1562,8 +1562,9 @@ bool IRTranslator::translateBitCast(const User &U,
bool IRTranslator::translateCast(unsigned Opcode, const User &U,
MachineIRBuilder &MIRBuilder) {
- if (U.getType()->getScalarType()->isBFloatTy() ||
- U.getOperand(0)->getType()->getScalarType()->isBFloatTy())
+ if (Opcode != TargetOpcode::G_BITCAST &&
+ (U.getType()->getScalarType()->isBFloatTy() ||
+ U.getOperand(0)->getType()->getScalarType()->isBFloatTy()))
return false;
Register Op = getOrCreateVReg(*U.getOperand(0));
Register Res = getOrCreateVReg(U);
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index 225e781588668f..787217171721d8 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -474,6 +474,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
bool isSSrcF64() const { return isSCSrc_b64() || isLiteralImm(MVT::f64); }
+ bool isSSrc_bf16() const { return isSCSrcB16() || isLiteralImm(MVT::bf16); }
+
bool isSSrc_f16() const { return isSCSrcB16() || isLiteralImm(MVT::f16); }
bool isSSrcV2F16() const {
@@ -540,22 +542,40 @@ class AMDGPUOperand : public MCParsedAsmOperand {
return isRegOrInlineNoMods(AMDGPU::VS_64RegClassID, MVT::f64);
}
+ bool isVCSrcTBF16() const {
+ return isRegOrInlineNoMods(AMDGPU::VS_16RegClassID, MVT::bf16);
+ }
+
bool isVCSrcTF16() const {
return isRegOrInlineNoMods(AMDGPU::VS_16RegClassID, MVT::f16);
}
+ bool isVCSrcTBF16_Lo128() const {
+ return isRegOrInlineNoMods(AMDGPU::VS_16_Lo128RegClassID, MVT::bf16);
+ }
+
bool isVCSrcTF16_Lo128() const {
return isRegOrInlineNoMods(AMDGPU::VS_16_Lo128RegClassID, MVT::f16);
}
+ bool isVCSrcFake16BF16_Lo128() const {
+ return isRegOrInlineNoMods(AMDGPU::VS_32_Lo128RegClassID, MVT::bf16);
+ }
+
bool isVCSrcFake16F16_Lo128() const {
return isRegOrInlineNoMods(AMDGPU::VS_32_Lo128RegClassID, MVT::f16);
}
+ bool isVCSrc_bf16() const {
+ return isRegOrInlineNoMods(AMDGPU::VS_32RegClassID, MVT::bf16);
+ }
+
bool isVCSrc_f16() const {
return isRegOrInlineNoMods(AMDGPU::VS_32RegClassID, MVT::f16);
}
+ bool isVCSrc_v2bf16() const { return isVCSrc_bf16(); }
+
bool isVCSrc_v2f16() const { return isVCSrc_f16(); }
bool isVSrc_b32() const {
@@ -596,18 +616,34 @@ class AMDGPUOperand : public MCParsedAsmOperand {
bool isVSrc_f64() const { return isVCSrcF64() || isLiteralImm(MVT::f64); }
+ bool isVSrcT_bf16() const { return isVCSrcTBF16() || isLiteralImm(MVT::bf16); }
+
bool isVSrcT_f16() const { return isVCSrcTF16() || isLiteralImm(MVT::f16); }
+ bool isVSrcT_bf16_Lo128() const {
+ return isVCSrcTBF16_Lo128() || isLiteralImm(MVT::bf16);
+ }
+
bool isVSrcT_f16_Lo128() const {
return isVCSrcTF16_Lo128() || isLiteralImm(MVT::f16);
}
+ bool isVSrcFake16_bf16_Lo128() const {
+ return isVCSrcFake16BF16_Lo128() || isLiteralImm(MVT::bf16);
+ }
+
bool isVSrcFake16_f16_Lo128() const {
return isVCSrcFake16F16_Lo128() || isLiteralImm(MVT::f16);
}
+ bool isVSrc_bf16() const { return isVCSrc_bf16() || isLiteralImm(MVT::bf16); }
+
bool isVSrc_f16() const { return isVCSrc_f16() || isLiteralImm(MVT::f16); }
+ bool isVSrc_v2bf16() const {
+ return isVSrc_bf16() || isLiteralImm(MVT::v2bf16);
+ }
+
bool isVSrc_v2f16() const { return isVSrc_f16() || isLiteralImm(MVT::v2f16); }
bool isVISrcB32() const {
@@ -634,6 +670,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
return isVISrcF16() || isVISrcB32();
}
+ bool isVISrc_64_bf16() const {
+ return isRegOrInlineNoMods(AMDGPU::VReg_64RegClassID, MVT::bf16);
+ }
+
bool isVISrc_64_f16() const {
return isRegOrInlineNoMods(AMDGPU::VReg_64RegClassID, MVT::f16);
}
@@ -802,6 +842,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
return isAISrc_128F16() || isAISrc_128_b32();
}
+ bool isVISrc_128_bf16() const {
+ return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::bf16);
+ }
+
bool isVISrc_128_f16() const {
return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::f16);
}
@@ -1889,6 +1933,14 @@ static const fltSemantics *getOpFltSemantics(uint8_t OperandType) {
case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_KIMM16:
return &APFloat::IEEEhalf();
+ case AMDGPU::OPERAND_REG_IMM_BF16:
+ case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
+ case AMDGPU::OPERAND_REG_INLINE_C_BF16:
+ case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
+ case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
+ case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
+ case AMDGPU::OPERAND_REG_IMM_V2BF16:
+ return &APFloat::BFloat();
default:
llvm_unreachable("unsupported fp type");
}
@@ -2185,17 +2237,24 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
case AMDGPU::OPERAND_REG_INLINE_AC_INT32:
case AMDGPU::OPERAND_REG_INLINE_AC_FP32:
case AMDGPU::OPERAND_REG_IMM_INT16:
+ case AMDGPU::OPERAND_REG_IMM_BF16:
case AMDGPU::OPERAND_REG_IMM_FP16:
+ case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_C_INT16:
+ case AMDGPU::OPERAND_REG_INLINE_C_BF16:
case AMDGPU::OPERAND_REG_INLINE_C_FP16:
case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
+ case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
+ case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
+ case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
case AMDGPU::OPERAND_REG_IMM_V2INT16:
+ case AMDGPU::OPERAND_REG_IMM_V2BF16:
case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP32:
case AMDGPU::OPERAND_REG_IMM_V2FP32:
@@ -2239,6 +2298,7 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
case AMDGPU::OPERAND_REG_INLINE_AC_INT32:
case AMDGPU::OPERAND_REG_INLINE_AC_FP32:
case AMDGPU::OPERAND_REG_IMM_V2INT16:
+ case AMDGPU::OPERAND_REG_IMM_V2BF16:
case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_REG_IMM_V2FP32:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP32:
@@ -2276,11 +2336,15 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
return;
case AMDGPU::OPERAND_REG_IMM_INT16:
+ case AMDGPU::OPERAND_REG_IMM_BF16:
case AMDGPU::OPERAND_REG_IMM_FP16:
+ case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_C_INT16:
+ case AMDGPU::OPERAND_REG_INLINE_C_BF16:
case AMDGPU::OPERAND_REG_INLINE_C_FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
+ case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
if (isSafeTruncation(Val, 16) &&
AMDGPU::isInlinableLiteral16(static_cast<int16_t>(Val),
@@ -2295,8 +2359,10 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
return;
case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
+ case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
+ case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16: {
assert(isSafeTruncation(Val, 16));
assert(AMDGPU::isInlinableLiteral16(static_cast<int16_t>(Val),
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
index abfa4a3531e8e1..96a0168f37e405 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
@@ -521,8 +521,11 @@ void AMDGPUInstPrinter::printImmediateV216(uint32_t Imm, uint8_t OpType,
if (printImmediateFloat32(Imm, STI, O))
return;
break;
+ case AMDGPU::OPERAND_REG_IMM_V2BF16:
case AMDGPU::OPERAND_REG_IMM_V2FP16:
+ case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
+ case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
if (isUInt<16>(Imm) &&
printImmediateFloat16(static_cast<uint16_t>(Imm), STI, O))
@@ -792,17 +795,24 @@ void AMDGPUInstPrinter::printRegularOperand(const MCInst *MI, unsigned OpNo,
case AMDGPU::OPERAND_REG_IMM_INT16:
printImmediateInt16(Op.getImm(), STI, O);
break;
+ case AMDGPU::OPERAND_REG_INLINE_C_BF16:
case AMDGPU::OPERAND_REG_INLINE_C_FP16:
+ case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
+ case AMDGPU::OPERAND_REG_IMM_BF16:
case AMDGPU::OPERAND_REG_IMM_FP16:
+ case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
printImmediate16(Op.getImm(), STI, O);
break;
case AMDGPU::OPERAND_REG_IMM_V2INT16:
+ case AMDGPU::OPERAND_REG_IMM_V2BF16:
case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
+ case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
+ case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
printImmediateV216(Op.getImm(), OpTy, STI, O);
break;
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp
index 11f5e456e8d348..9ec174ba56c242 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp
@@ -276,9 +276,13 @@ AMDGPUMCCodeEmitter::getLitEncoding(const MCOperand &MO,
case AMDGPU::OPERAND_REG_INLINE_C_INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
return getLit16IntEncoding(static_cast<uint16_t>(Imm), STI);
+ case AMDGPU::OPERAND_REG_IMM_BF16:
case AMDGPU::OPERAND_REG_IMM_FP16:
+ case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
+ case AMDGPU::OPERAND_REG_INLINE_C_BF16:
case AMDGPU::OPERAND_REG_INLINE_C_FP16:
+ case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
// FIXME Is this correct? What do inline immediates do on SI for f16 src
// which does not have f16 support?
@@ -288,8 +292,11 @@ AMDGPUMCCodeEmitter::getLitEncoding(const MCOperand &MO,
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
return AMDGPU::getInlineEncodingV2I16(static_cast<uint32_t>(Imm))
.value_or(255);
+ case AMDGPU::OPERAND_REG_IMM_V2BF16:
case AMDGPU::OPERAND_REG_IMM_V2FP16:
+ case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
+ case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
return AMDGPU::getInlineEncodingV2F16(static_cast<uint32_t>(Imm))
.value_or(255);
diff --git a/llvm/lib/Target/AMDGPU/SIDefines.h b/llvm/lib/Target/AMDGPU/SIDefines.h
index 19596d53b45328..66b997eb180613 100644
--- a/llvm/lib/Target/AMDGPU/SIDefines.h
+++ b/llvm/lib/Target/AMDGPU/SIDefines.h
@@ -196,9 +196,12 @@ enum OperandType : unsigned {
OPERAND_REG_IMM_INT16,
OPERAND_REG_IMM_FP32,
OPERAND_REG_IMM_FP64,
+ OPERAND_REG_IMM_BF16,
OPERAND_REG_IMM_FP16,
+ OPERAND_REG_IMM_BF16_DEFERRED,
OPERAND_REG_IMM_FP16_DEFERRED,
OPERAND_REG_IMM_FP32_DEFERRED,
+ OPERAND_REG_IMM_V2BF16,
OPERAND_REG_IMM_V2FP16,
OPERAND_REG_IMM_V2INT16,
OPERAND_REG_IMM_V2INT32,
@@ -208,10 +211,12 @@ enum OperandType : unsigned {
OPERAND_REG_INLINE_C_INT16,
OPERAND_REG_INLINE_C_INT32,
OPERAND_REG_INLINE_C_INT64,
+ OPERAND_REG_INLINE_C_BF16,
OPERAND_REG_INLINE_C_FP16,
OPERAND_REG_INLINE_C_FP32,
OPERAND_REG_INLINE_C_FP64,
OPERAND_REG_INLINE_C_V2INT16,
+ OPERAND_REG_INLINE_C_V2BF16,
OPERAND_REG_INLINE_C_V2FP16,
OPERAND_REG_INLINE_C_V2INT32,
OPERAND_REG_INLINE_C_V2FP32,
@@ -226,10 +231,12 @@ enum OperandType : unsigned {
/// Operands with an AccVGPR register or inline constant
OPERAND_REG_INLINE_AC_INT16,
OPERAND_REG_INLINE_AC_INT32,
+ OPERAND_REG_INLINE_AC_BF16,
OPERAND_REG_INLINE_AC_FP16,
OPERAND_REG_INLINE_AC_FP32,
OPERAND_REG_INLINE_AC_FP64,
OPERAND_REG_INLINE_AC_V2INT16,
+ OPERAND_REG_INLINE_AC_V2BF16,
OPERAND_REG_INLINE_AC_V2FP16,
OPERAND_REG_INLINE_AC_V2INT32,
OPERAND_REG_INLINE_AC_V2FP32,
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp
index c7628bd354309c..fcb2a6f1f3d75d 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp
@@ -4181,13 +4181,20 @@ bool SIInstrInfo::isInlineConstant(const MachineOperand &MO,
case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
return AMDGPU::isInlinableLiteralV2I16(Imm);
+ case AMDGPU::OPERAND_REG_IMM_V2BF16:
case AMDGPU::OPERAND_REG_IMM_V2FP16:
+ case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
+ case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
return AMDGPU::isInlinableLiteralV2F16(Imm);
+ case AMDGPU::OPERAND_REG_IMM_BF16:
case AMDGPU::OPERAND_REG_IMM_FP16:
+ case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
+ case AMDGPU::OPERAND_REG_INLINE_C_BF16:
case AMDGPU::OPERAND_REG_INLINE_C_FP16:
+ case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_FP16: {
if (isInt<16>(Imm) || isUInt<16>(Imm)) {
// A few special case instructions have 16-bit operands on subtargets
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
index 7edec5a7a5505b..056ef1403d9685 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
@@ -1490,20 +1490,17 @@ class getVOPSrc0ForVT<ValueType VT, bit IsTrue16, bit IsFake16 = 1> {
RegisterOperand ret =
!if(VT.isFP,
!if(!eq(VT.Size, 64),
- VSrc_f64,
- !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
- !if(IsTrue16,
- !if(IsFake16, VSrcFake16_f16_Lo128, VSrcT_f16_Lo128),
- VSrc_f16
- ),
- !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
- VSrc_v2f16,
- !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
- AVSrc_64,
- VSrc_f32
+ VSrc_f64,
+ !if(!eq(VT.Value, f16.Value),
+ !if(IsTrue16, !if(IsFake16, VSrcFake16_f16_Lo128, VSrcT_f16_Lo128), VSrc_f16),
+ !if(!eq(VT.Value, bf16.Value),
+ !if(IsTrue16, !if(IsFake16, VSrcFake16_bf16_Lo128, VSrcT_bf16_Lo128), VSrc_bf16),
+ !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
+ !if(!eq(VT.Value, v2f16.Value), VSrc_v2f16, VSrc_v2bf16),
+ !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)), AVSrc_64, VSrc_f32)
+ )
)
- )
- )
+ )
),
!if(!eq(VT.Size, 64),
VSrc_b64,
@@ -1562,16 +1559,20 @@ class getVOP3SrcForVT<ValueType VT, bit IsTrue16 = 0> {
!if(!eq(VT.Value, i1.Value),
SSrc_i1,
!if(VT.isFP,
- !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
- !if(IsTrue16, VSrcT_f16, VSrc_f16),
- !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
- VSrc_v2f16,
- !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
- AVSrc_64,
- VSrc_f32
- )
- )
- ),
+ !if(!eq(VT.Value, f16.Value),
+ !if(IsTrue16, VSrcT_f16, VSrc_f16),
+ !if(!eq(VT.Value, bf16.Value),
+ !if(IsTrue16, VSrcT_bf16, VSrc_bf16),
+ !if(!eq(VT.Value, v2f16.Value),
+ VSrc_v2f16,
+ !if(!eq(VT.Value, v2bf16.Value),
+ VSrc_v2bf16,
+ !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
+ AVSrc_64, VSrc_f32)
+ )
+ )
+ )
+ ),
!if(!eq(VT.Value, i16.Value),
!if(IsTrue16, VSrcT_b16, VSrc_b16),
!if(!eq(VT.Value, v2i16.Value),
@@ -1590,8 +1591,13 @@ class getVOP3DPPSrcForVT<ValueType VT> {
RegisterOperand ret =
!if (!eq(VT.Value, i1.Value), SSrc_i1,
!if (VT.isFP,
- !if (!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)), VCSrc_f16,
- !if (!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)), VCSrc_v2f16, VCSrc_f32)),
+ !if(!eq(VT.Value, f16.Value), VCSrc_f16,
+ !if(!eq(VT.Value, bf16.Value), VCSrc_bf16,
+ !if(!eq(VT.Value, v2f16.Value), VCSrc_v2f16,
+ !if(!eq(VT.Value, v2bf16.Value), VCSrc_v2bf16, VCSrc_f32)
+ )
+ )
+ ),
!if (!eq(VT.Value, i16.Value), VCSrc_b16,
!if (!eq(VT.Value, v2i16.Value), VCSrc_v2b16,
VCSrc_b32))));
@@ -2513,8 +2519,8 @@ def VOP_V2I16_F32_F32 : VOPProfile <[v2i16, f32, f32, untyped]>;
def VOP_V2I16_I32_I32 : VOPProfile <[v2i16, i32, i32, untyped]>;
def VOP_F16_V2F16_V2F16_F16 : VOPProfile <[f16, v2f16, v2f16, f16]>;
-def VOP_I16_V2I16_V2I16_I16 : VOPProfile <[i16, v2i16, v2i16, i16]>;
-def VOP_F32_V2I16_V2I16_F32 : VOPProfile <[f32, v2i16, v2i16, f32]>;
+def VOP_BF16_V2BF16_V2BF16_BF16: VOPProfile <[bf16, v2bf16, v2bf16, bf16]>;
+def VOP_F32_V2BF16_V2BF16_F32 : VOPProfile <[f32, v2bf16, v2bf16, f32]>;
def VOP_F32_V2F16_V2F16_V2F16 : VOPProfile <[f32, v2f16, v2f16, v2f16]>;
diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
index c9dbe02037ef2e..5c5dd6a4e1a63f 100644
--- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
@@ -1066,7 +1066,7 @@ multiclass AVRegClass<int numRegs, list<ValueType> regTypes,
// Define the regular class.
def "" : VRegClassBase<numRegs, regTypes, (add vregList, aregList)>;
- // Define 2-aligned variant
+ // Define 2-aligned variant
def _Align2 : VRegClassBase<numRegs, regTypes,
(add (decimate vregList, 2),
(decimate aregList, 2))> {
@@ -1115,6 +1115,7 @@ class RegOrImmOperand <string RegisterClassName, string OperandTypeName,
//===----------------------------------------------------------------------===//
def SSrc_b16 : RegOrImmOperand <"SReg_32", "OPERAND_REG_IMM_INT16", "_Imm16">;
+def SSrc_bf16: RegOrImmOperand <"SReg_32", "OPERAND_REG_IMM_BF16", "_Imm16">;
def SSrc_f16 : RegOrImmOperand <"SReg_32", "OPERAND_REG_IMM_FP16", "_Imm16">;
def SSrc_b32 : RegOrImmOperand <"SReg_32", "OPERAND_REG_IMM_INT32", "_Imm32">;
def SSrc_f32 : RegOrImmOperand <"SReg_32", "OPERAND_REG_IMM_FP32", "_Imm32">;
@@ -1142,6 +1143,7 @@ def SCSrc_b64 : RegOrImmOperand <"SReg_64", "OPERAND_REG_INLINE_C_INT64", "_Imm6
// The current and temporary future default used case for VOP3.
def VSrc_b16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_INT16", "_Imm16">;
+def VSrc_bf16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_BF16", "_Imm16">;
def VSrc_f16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_FP16", "_Imm16">;
// True16 VOP3 operands.
@@ -1149,6 +1151,10 @@ def VSrcT_b16 : RegOrImmOperand <"VS_16", "OPERAND_REG_IMM_INT16", "_Imm16"> {
let EncoderMethod = "getMachineOpValueT16";
let DecoderMethod = "decodeOperand_VSrcT16";
}
+def VSrcT_bf16 : RegOrImmOperand <"VS_16", "OPERAND_REG_IMM_BF16", "_Imm16"> {
+ let EncoderMethod = "getMachineOpValueT16";
+ let DecoderMethod = "decodeOperand_VSrcT16";
+}
def VSrcT_f16 : RegOrImmOperand <"VS_16", "OPERAND_REG_IMM_FP16", "_Imm16"> {
let EncoderMethod = "getMachineOpValueT16";
let DecoderMethod = "decodeOperand_VSrcT16";
@@ -1159,6 +1165,10 @@ def VSrcT_b16_Lo128 : RegOrImmOperand <"VS_16_Lo128", "OPERAND_REG_IMM_INT16", "
let EncoderMethod = "getMachineOpValueT16Lo128";
let DecoderMethod = "decodeOperand_VSrcT16_Lo128";
}
+def VSrcT_bf16_Lo128 : RegOrImmOperand <"VS_16_Lo128", "OPERAND_REG_IMM_BF16", "_Imm16"> {
+ let EncoderMethod = "getMachineOpValueT16Lo128";
+ let DecoderMethod = "decodeOperand_VSrcT16_Lo128";
+}
def VSrcT_f16_Lo128 : RegOrImmOperand <"VS_16_Lo128", "OPERAND_REG_IMM_FP16", "_Imm16"> {
let EncoderMethod = "getMachineOpValueT16Lo128";
let DecoderMethod = "decodeOperand_VSrcT16_Lo128";
@@ -1167,11 +1177,13 @@ def VSrcT_f16_Lo128 : RegOrImmOperand <"VS_16_Lo128", "OPERAND_REG_IMM_FP16", "_
// The current and temporary future default used case for fake VOP1/2/C.
// For VOP1,2,C True16 instructions. _Lo128 use first 128 32-bit VGPRs only.
def VSrcFake16_b16_Lo128 : RegOrImmOperand <"VS_32_Lo128", "OPERAND_REG_IMM_INT16", "_Imm16">;
+def VSrcFake16_bf16_Lo128 : RegOrImmOperand <"VS_32_Lo128", "OPERAND_REG_IMM_BF16", "_Imm16">;
def VSrcFake16_f16_Lo128 : RegOrImmOperand <"VS_32_Lo128", "OPERAND_REG_IMM_FP16", "_Imm16">;
def VSrc_b32 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_INT32", "_Imm32">;
def VSrc_f32 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_FP32", "_Imm32">;
def VSrc_v2b16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_V2INT16", "_ImmV2I16">;
+def VSrc_v2bf16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_V2BF16", "_ImmV2F16">;
def VSrc_v2f16 : RegOrImmOperand <"VS_32", "OPERAND_REG_IMM_V2FP16", "_ImmV2F16">;
def VSrc_b64 : RegOrImmOperand <"VS_64", "OPERAND_REG_IMM_INT64", "_Imm64">;
def VSrc_f64 : RegOrImmOperand <"VS_64", "OPERAND_REG_IMM_FP64", "_Imm64"> {
@@ -1185,9 +1197,13 @@ def VSrc_v2f32 : RegOrImmOperand <"VS_64", "OPERAND_REG_IMM_V2FP32", "_Imm32">;
// with FMAMK/FMAAK
//===----------------------------------------------------------------------===//
+def VSrc_bf16_Deferred : RegOrImmOperand<"VS_32", "OPERAND_REG_IMM_BF16_DEFERRED", "_Deferred_Imm16">;
def VSrc_f16_Deferred : RegOrImmOperand<"VS_32", "OPERAND_REG_IMM_FP16_DEFERRED", "_Deferred_Imm16">;
def VSrc_f32_Deferred : RegOrImmOperand<"VS_32", "OPERAND_REG_IMM_FP32_DEFERRED", "_Deferred_Imm32">;
+def VSrcFake16_bf16_Lo128_Deferred : RegOrImmOperand<"VS_32_Lo128",
+ "OPERAND_REG_IMM_BF16_DEFERRED",
+ "_Deferred_Imm16">;
def VSrcFake16_f16_Lo128_Deferred : RegOrImmOperand<"VS_32_Lo128",
"OPERAND_REG_IMM_FP16_DEFERRED",
"_Deferred_Imm16">;
@@ -1252,19 +1268,23 @@ def ARegSrc_32 : AVOperand<AGPR_32, "decodeSrcA9", "OPW32">;
//===----------------------------------------------------------------------===//
def VCSrc_b16 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_INT16", "_Imm16">;
+def VCSrc_bf16: RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_BF16", "_Imm16">;
def VCSrc_f16 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_FP16", "_Imm16">;
def VCSrc_b32 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_INT32", "_Imm32">;
def VCSrc_f32 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_FP32", "_Imm32">;
def VCSrc_v2b16 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_V2INT16", "_ImmV2I16">;
+def VCSrc_v2bf16 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_V2BF16", "_ImmV2F16">;
def VCSrc_v2f16 : RegOrImmOperand <"VS_32", "OPERAND_REG_INLINE_C_V2FP16", "_ImmV2F16">;
//===----------------------------------------------------------------------===//
// VISrc_* Operands with a VGPR or an inline constant
//===----------------------------------------------------------------------===//
+def VISrc_64_bf16 : RegOrImmOperand <"VReg_64", "OPERAND_REG_INLINE_C_BF16", "_Imm16">;
def VISrc_64_f16 : RegOrImmOperand <"VReg_64", "OPERAND_REG_INLINE_C_FP16", "_Imm16">;
def VISrc_64_b32 : RegOrImmOperand <"VReg_64", "OPERAND_REG_INLINE_C_INT32", "_Imm32">;
def VISrc_64_f64 : RegOrImmOperand <"VReg_64", "OPERAND_REG_INLINE_C_FP64", "_Imm64">;
+def VISrc_128_bf16 : RegOrImmOperand <"VReg_128", "OPERAND_REG_INLINE_C_BF16", "_Imm16">;
def VISrc_128_f16 : RegOrImmOperand <"VReg_128", "OPERAND_REG_INLINE_C_FP16", "_Imm16">;
def VISrc_128_b32 : RegOrImmOperand <"VReg_128", "OPERAND_REG_INLINE_C_INT32", "_Imm32">;
def VISrc_128_f32 : RegOrImmOperand <"VReg_128", "OPERAND_REG_INLINE_C_FP32", "_Imm32">;
diff --git a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h
index f24b9f0e3615de..1542e5f531001a 100644
--- a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h
+++ b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h
@@ -1321,17 +1321,24 @@ inline unsigned getOperandSize(const MCOperandInfo &OpInfo) {
return 8;
case AMDGPU::OPERAND_REG_IMM_INT16:
+ case AMDGPU::OPERAND_REG_IMM_BF16:
case AMDGPU::OPERAND_REG_IMM_FP16:
+ case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_C_INT16:
+ case AMDGPU::OPERAND_REG_INLINE_C_BF16:
case AMDGPU::OPERAND_REG_INLINE_C_FP16:
case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
+ case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
+ case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
+ case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
case AMDGPU::OPERAND_REG_IMM_V2INT16:
+ case AMDGPU::OPERAND_REG_IMM_V2BF16:
case AMDGPU::OPERAND_REG_IMM_V2FP16:
return 2;
diff --git a/llvm/lib/Target/AMDGPU/VOP3Instructions.td b/llvm/lib/Target/AMDGPU/VOP3Instructions.td
index 8d965d3b9041d5..35cffa22f45929 100644
--- a/llvm/lib/Target/AMDGPU/VOP3Instructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3Instructions.td
@@ -904,7 +904,7 @@ let SubtargetPredicate = isGFX12Plus, ReadsModeReg = 0 in {
let SubtargetPredicate = HasDot9Insts, IsDOT=1 in {
defm V_DOT2_F16_F16 : VOP3Inst<"v_dot2_f16_f16", VOP3_DOT_Profile<VOP_F16_V2F16_V2F16_F16>, int_amdgcn_fdot2_f16_f16>;
- defm V_DOT2_BF16_BF16 : VOP3Inst<"v_dot2_bf16_bf16", VOP3_DOT_Profile<VOP_I16_V2I16_V2I16_I16>, int_amdgcn_fdot2_bf16_bf16>;
+ defm V_DOT2_BF16_BF16 : VOP3Inst<"v_dot2_bf16_bf16", VOP3_DOT_Profile<VOP_BF16_V2BF16_V2BF16_BF16>, int_amdgcn_fdot2_bf16_bf16>;
}
class VOP_Pseudo_Scalar<RegisterClass Dst, RegisterOperand SrcOp,
diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
index ef14a587c42e79..5686b2e539e31d 100644
--- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
@@ -396,7 +396,7 @@ defm V_DOT8_I32_I4 : VOP3PInst<"v_dot8_i32_i4",
} // End OtherPredicates = [HasDot1Insts]
def DOT2_BF16_Profile
- : VOP3P_Profile<VOP_F32_V2I16_V2I16_F32, VOP3_REGULAR, /*HasDPP*/ 1> {
+ : VOP3P_Profile<VOP_F32_V2BF16_V2BF16_F32, VOP3_REGULAR, /*HasDPP*/ 1> {
let HasSrc1Mods = 1;
}
diff --git a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.fdot2.bf16.bf16.ll b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.fdot2.bf16.bf16.ll
index 645e00b6a4a819..b1b44983e9daf8 100644
--- a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.fdot2.bf16.bf16.ll
+++ b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.fdot2.bf16.bf16.ll
@@ -2,7 +2,7 @@
; RUN: llc -mtriple=amdgcn -mcpu=gfx1100 -verify-machineinstrs < %s | FileCheck %s --check-prefixes=GFX11,SDAG-GFX11
; RUN: llc -global-isel -mtriple=amdgcn -mcpu=gfx1100 -verify-machineinstrs < %s | FileCheck %s --check-prefixes=GFX11,GISEL-GFX11
-declare i16 @llvm.amdgcn.fdot2.bf16.bf16(<2 x i16> %a, <2 x i16> %b, i16 %c)
+declare bfloat @llvm.amdgcn.fdot2.bf16.bf16(<2 x bfloat> %a, <2 x bfloat> %b, bfloat %c)
define amdgpu_kernel void @test_llvm_amdgcn_fdot2_bf16_bf16(
; GFX11-LABEL: test_llvm_amdgcn_fdot2_bf16_bf16:
@@ -24,11 +24,11 @@ define amdgpu_kernel void @test_llvm_amdgcn_fdot2_bf16_bf16(
ptr addrspace(1) %b,
ptr addrspace(1) %c) {
entry:
- %a.val = load <2 x i16>, ptr addrspace(1) %a
- %b.val = load <2 x i16>, ptr addrspace(1) %b
- %c.val = load i16, ptr addrspace(1) %c
- %r.val = call i16 @llvm.amdgcn.fdot2.bf16.bf16(<2 x i16> %a.val, <2 x i16> %b.val, i16 %c.val)
- store i16 %r.val, ptr addrspace(1) %r
+ %a.val = load <2 x bfloat>, ptr addrspace(1) %a
+ %b.val = load <2 x bfloat>, ptr addrspace(1) %b
+ %c.val = load bfloat, ptr addrspace(1) %c
+ %r.val = call bfloat @llvm.amdgcn.fdot2.bf16.bf16(<2 x bfloat> %a.val, <2 x bfloat> %b.val, bfloat %c.val)
+ store bfloat %r.val, ptr addrspace(1) %r
ret void
}
@@ -61,14 +61,14 @@ define amdgpu_kernel void @test_llvm_amdgcn_fdot2_bf16_bf16_dpp(
ptr addrspace(5) %b,
ptr addrspace(5) %c) {
entry:
- %a.val = load <2 x i16>, ptr addrspace(5) %a
- %b.val = load <2 x i16>, ptr addrspace(5) %b
- %c.val = load i16, ptr addrspace(5) %c
- %a.val.i32 = bitcast <2 x i16> %a.val to i32
+ %a.val = load <2 x bfloat>, ptr addrspace(5) %a
+ %b.val = load <2 x bfloat>, ptr addrspace(5) %b
+ %c.val = load bfloat, ptr addrspace(5) %c
+ %a.val.i32 = bitcast <2 x bfloat> %a.val to i32
%dpp = call i32 @llvm.amdgcn.update.dpp.i32(i32 %a.val.i32, i32 %a.val.i32, i32 1, i32 15, i32 15, i1 1)
- %a.val.dpp.v2i16 = bitcast i32 %dpp to <2 x i16>
- %r.val = call i16 @llvm.amdgcn.fdot2.bf16.bf16(<2 x i16> %a.val.dpp.v2i16, <2 x i16> %b.val, i16 %c.val)
- store i16 %r.val, ptr addrspace(5) %r
+ %a.val.dpp.v2bfloat = bitcast i32 %dpp to <2 x bfloat>
+ %r.val = call bfloat @llvm.amdgcn.fdot2.bf16.bf16(<2 x bfloat> %a.val.dpp.v2bfloat, <2 x bfloat> %b.val, bfloat %c.val)
+ store bfloat %r.val, ptr addrspace(5) %r
ret void
}
@@ -79,17 +79,17 @@ define amdgpu_ps void @test_llvm_amdgcn_fdot2_bf16_bf16_sis(
; GFX11: ; %bb.0: ; %entry
; GFX11-NEXT: v_mov_b32_e32 v2, s1
; GFX11-NEXT: s_delay_alu instid0(VALU_DEP_1)
-; GFX11-NEXT: v_dot2_bf16_bf16 v2, s0, 0x10001, v2
+; GFX11-NEXT: v_dot2_bf16_bf16 v2, s0, 0x3f803f80, v2
; GFX11-NEXT: global_store_b16 v[0:1], v2, off
; GFX11-NEXT: s_nop 0
; GFX11-NEXT: s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
; GFX11-NEXT: s_endpgm
ptr addrspace(1) %r,
- <2 x i16> inreg %a,
- i16 inreg %c) {
+ <2 x bfloat> inreg %a,
+ bfloat inreg %c) {
entry:
- %r.val = call i16 @llvm.amdgcn.fdot2.bf16.bf16(<2 x i16> %a, <2 x i16> <i16 1, i16 1>, i16 %c)
- store i16 %r.val, ptr addrspace(1) %r
+ %r.val = call bfloat @llvm.amdgcn.fdot2.bf16.bf16(<2 x bfloat> %a, <2 x bfloat> <bfloat 1.0, bfloat 1.0>, bfloat %c)
+ store bfloat %r.val, ptr addrspace(1) %r
ret void
}
diff --git a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.fdot2.f32.bf16.ll b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.fdot2.f32.bf16.ll
index 367ff57bae2fd6..e51b1d2da2e414 100644
--- a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.fdot2.f32.bf16.ll
+++ b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.fdot2.f32.bf16.ll
@@ -2,7 +2,7 @@
; RUN: llc -mtriple=amdgcn -mcpu=gfx1100 -verify-machineinstrs < %s | FileCheck %s --check-prefixes=GFX11
; RUN: llc -global-isel -mtriple=amdgcn -mcpu=gfx1100 -verify-machineinstrs < %s | FileCheck %s --check-prefixes=GFX11
-declare float @llvm.amdgcn.fdot2.f32.bf16(<2 x i16> %a, <2 x i16> %b, float %c, i1 %clamp)
+declare float @llvm.amdgcn.fdot2.f32.bf16(<2 x bfloat> %a, <2 x bfloat> %b, float %c, i1 %clamp)
define amdgpu_kernel void @test_llvm_amdgcn_fdot2_f32_bf16_clamp(
; GFX11-LABEL: test_llvm_amdgcn_fdot2_f32_bf16_clamp:
@@ -25,10 +25,10 @@ define amdgpu_kernel void @test_llvm_amdgcn_fdot2_f32_bf16_clamp(
ptr addrspace(1) %b,
ptr addrspace(1) %c) {
entry:
- %a.val = load <2 x i16>, ptr addrspace(1) %a
- %b.val = load <2 x i16>, ptr addrspace(1) %b
+ %a.val = load <2 x bfloat>, ptr addrspace(1) %a
+ %b.val = load <2 x bfloat>, ptr addrspace(1) %b
%c.val = load float, ptr addrspace(1) %c
- %r.val = call float @llvm.amdgcn.fdot2.f32.bf16(<2 x i16> %a.val, <2 x i16> %b.val, float %c.val, i1 1)
+ %r.val = call float @llvm.amdgcn.fdot2.f32.bf16(<2 x bfloat> %a.val, <2 x bfloat> %b.val, float %c.val, i1 1)
store float %r.val, ptr addrspace(1) %r
ret void
}
@@ -55,10 +55,10 @@ define amdgpu_kernel void @test_llvm_amdgcn_fdot2_f32_bf16_no_clamp(
ptr addrspace(1) %b,
ptr addrspace(1) %c) {
entry:
- %a.val = load <2 x i16>, ptr addrspace(1) %a
- %b.val = load <2 x i16>, ptr addrspace(1) %b
+ %a.val = load <2 x bfloat>, ptr addrspace(1) %a
+ %b.val = load <2 x bfloat>, ptr addrspace(1) %b
%c.val = load float, ptr addrspace(1) %c
- %r.val = call float @llvm.amdgcn.fdot2.f32.bf16(<2 x i16> %a.val, <2 x i16> %b.val, float %c.val, i1 0)
+ %r.val = call float @llvm.amdgcn.fdot2.f32.bf16(<2 x bfloat> %a.val, <2 x bfloat> %b.val, float %c.val, i1 0)
store float %r.val, ptr addrspace(1) %r
ret void
}
diff --git a/llvm/test/MC/AMDGPU/bf16_imm.s b/llvm/test/MC/AMDGPU/bf16_imm.s
new file mode 100644
index 00000000000000..479540ffb8c0ae
--- /dev/null
+++ b/llvm/test/MC/AMDGPU/bf16_imm.s
@@ -0,0 +1,8 @@
+// RUN: llvm-mc -arch=amdgcn -mcpu=gfx1100 -show-encoding %s | FileCheck %s
+// RUN: llvm-mc -arch=amdgcn -mcpu=gfx1200 -show-encoding %s | FileCheck %s
+// Floating-point literals on bf16 operands are expected to be emitted as bf16 bit patterns (100.0 -> 0x42c8, 1.0 -> 0x3f80).
+v_dot2_bf16_bf16 v5, v1, v2, 100.0
+// CHECK: v_dot2_bf16_bf16 v5, v1, v2, 0x42c8 ; encoding: [0x05,0x00,0x67,0xd6,0x01,0x05,0xfe,0x03,0xc8,0x42,0x00,0x00]
+
+v_dot2_bf16_bf16 v5, v1, v2, 1.0
+// CHECK: v_dot2_bf16_bf16 v5, v1, v2, 0x3f80 ; encoding: [0x05,0x00,0x67,0xd6,0x01,0x05,0xfe,0x03,0x80,0x3f,0x00,0x00]
More information about the llvm-commits
mailing list