[llvm] 3699811 - [AMDGPU] Handle bf16 operands the same way as f16. NFC. (#77826)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 11 21:08:23 PST 2024
Author: Stanislav Mekhanoshin
Date: 2024-01-11T21:08:19-08:00
New Revision: 369981181ffa75a8500416982417662f1fa04704
URL: https://github.com/llvm/llvm-project/commit/369981181ffa75a8500416982417662f1fa04704
DIFF: https://github.com/llvm/llvm-project/commit/369981181ffa75a8500416982417662f1fa04704.diff
LOG: [AMDGPU] Handle bf16 operands the same way as f16. NFC. (#77826)
This is an infrastructure change that will allow bf16 operands to be used
in instruction definitions.
Added:
Modified:
llvm/lib/Target/AMDGPU/SIInstrInfo.td
Removed:
################################################################################
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
index 04c92155f5aada..441d72cc173f5b 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
@@ -284,12 +284,17 @@ def SIfptrunc_round_downward : SDNode<"AMDGPUISD::FPTRUNC_ROUND_DOWNWARD",
// Returns 1 if the source arguments have modifiers, 0 if they do not.
class isFloatType<ValueType SrcVT> {
bit ret = !or(!eq(SrcVT.Value, f16.Value),
+ !eq(SrcVT.Value, bf16.Value),
!eq(SrcVT.Value, f32.Value),
!eq(SrcVT.Value, f64.Value),
!eq(SrcVT.Value, v2f16.Value),
+ !eq(SrcVT.Value, v2bf16.Value),
!eq(SrcVT.Value, v4f16.Value),
+ !eq(SrcVT.Value, v4bf16.Value),
!eq(SrcVT.Value, v8f16.Value),
+ !eq(SrcVT.Value, v8bf16.Value),
!eq(SrcVT.Value, v16f16.Value),
+ !eq(SrcVT.Value, v16bf16.Value),
!eq(SrcVT.Value, v2f32.Value),
!eq(SrcVT.Value, v4f32.Value),
!eq(SrcVT.Value, v8f32.Value),
@@ -314,7 +319,9 @@ class isIntType<ValueType SrcVT> {
class isPackedType<ValueType SrcVT> {
bit ret = !or(!eq(SrcVT.Value, v2i16.Value),
!eq(SrcVT.Value, v2f16.Value),
+ !eq(SrcVT.Value, v2bf16.Value),
!eq(SrcVT.Value, v4f16.Value),
+ !eq(SrcVT.Value, v4bf16.Value),
!eq(SrcVT.Value, v2i32.Value),
!eq(SrcVT.Value, v2f32.Value),
!eq(SrcVT.Value, v4i32.Value),
@@ -1495,14 +1502,14 @@ class getVOPSrc0ForVT<ValueType VT, bit IsTrue16, bit IsFake16 = 1> {
!if(isFP,
!if(!eq(VT.Size, 64),
VSrc_f64,
- !if(!eq(VT.Value, f16.Value),
+ !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
!if(IsTrue16,
!if(IsFake16, VSrcFake16_f16_Lo128, VSrcT_f16_Lo128),
VSrc_f16
),
- !if(!eq(VT.Value, v2f16.Value),
+ !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
VSrc_v2f16,
- !if(!eq(VT.Value, v4f16.Value),
+ !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
AVSrc_64,
VSrc_f32
)
@@ -1576,11 +1583,11 @@ class getVOP3SrcForVT<ValueType VT, bit IsTrue16 = 0> {
!if(!eq(VT.Value, i1.Value),
SSrc_i1,
!if(isFP,
- !if(!eq(VT.Value, f16.Value),
+ !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
!if(IsTrue16, VSrcT_f16, VSrc_f16),
- !if(!eq(VT.Value, v2f16.Value),
+ !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
VSrc_v2f16,
- !if(!eq(VT.Value, v4f16.Value),
+ !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
AVSrc_64,
VSrc_f32
)
@@ -1605,8 +1612,8 @@ class getVOP3DPPSrcForVT<ValueType VT> {
RegisterOperand ret =
!if (!eq(VT.Value, i1.Value), SSrc_i1,
!if (isFP,
- !if (!eq(VT.Value, f16.Value), VCSrc_f16,
- !if (!eq(VT.Value, v2f16.Value), VCSrc_v2f16, VCSrc_f32)),
+ !if (!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)), VCSrc_f16,
+ !if (!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)), VCSrc_v2f16, VCSrc_f32)),
!if (!eq(VT.Value, i16.Value), VCSrc_b16,
!if (!eq(VT.Value, v2i16.Value), VCSrc_v2b16,
VCSrc_b32))));
@@ -1615,22 +1622,27 @@ class getVOP3DPPSrcForVT<ValueType VT> {
// Float or packed int
class isModifierType<ValueType SrcVT> {
bit ret = !or(!eq(SrcVT.Value, f16.Value),
+ !eq(SrcVT.Value, bf16.Value),
!eq(SrcVT.Value, f32.Value),
!eq(SrcVT.Value, f64.Value),
!eq(SrcVT.Value, v2f16.Value),
!eq(SrcVT.Value, v2i16.Value),
+ !eq(SrcVT.Value, v2bf16.Value),
!eq(SrcVT.Value, v2f32.Value),
!eq(SrcVT.Value, v2i32.Value),
!eq(SrcVT.Value, v4f16.Value),
!eq(SrcVT.Value, v4i16.Value),
+ !eq(SrcVT.Value, v4bf16.Value),
!eq(SrcVT.Value, v4f32.Value),
!eq(SrcVT.Value, v4i32.Value),
!eq(SrcVT.Value, v8f16.Value),
!eq(SrcVT.Value, v8i16.Value),
+ !eq(SrcVT.Value, v8bf16.Value),
!eq(SrcVT.Value, v8f32.Value),
!eq(SrcVT.Value, v8i32.Value),
!eq(SrcVT.Value, v16f16.Value),
- !eq(SrcVT.Value, v16i16.Value));
+ !eq(SrcVT.Value, v16i16.Value),
+ !eq(SrcVT.Value, v16bf16.Value));
}
// Return type of input modifiers operand for specified input operand
@@ -1646,7 +1658,8 @@ class getSrcMod <ValueType VT, bit IsTrue16 = 0> {
}
class getOpSelMod <ValueType VT> {
- Operand ret = !if(!eq(VT.Value, f16.Value), FP16InputMods, IntOpSelMods);
+ Operand ret = !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
+ FP16InputMods, IntOpSelMods);
}
// Return type of input modifiers operand specified input operand for DPP
@@ -1659,8 +1672,8 @@ class getSrcModDPP_t16 <ValueType VT> {
bit isFP = isFloatType<VT>.ret;
Operand ret =
!if (isFP,
- !if (!eq(VT.Value, f16.Value), FPT16VRegInputMods,
- FPVRegInputMods),
+ !if (!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
+ FPT16VRegInputMods, FPVRegInputMods),
!if (!eq(VT.Value, i16.Value), IntT16VRegInputMods,
IntVRegInputMods));
}
@@ -1671,8 +1684,8 @@ class getSrcModVOP3DPP <ValueType VT> {
bit isPacked = isPackedType<VT>.ret;
Operand ret =
!if (isFP,
- !if (!eq(VT.Value, f16.Value), FP16VCSrcInputMods,
- FP32VCSrcInputMods),
+ !if (!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
+ FP16VCSrcInputMods, FP32VCSrcInputMods),
Int32VCSrcInputMods);
}
@@ -1681,7 +1694,8 @@ class getSrcModSDWA <ValueType VT> {
Operand ret = !if(!eq(VT.Value, f16.Value), FP16SDWAInputMods,
!if(!eq(VT.Value, f32.Value), FP32SDWAInputMods,
!if(!eq(VT.Value, i16.Value), Int16SDWAInputMods,
- Int32SDWAInputMods)));
+ !if(!eq(VT.Value, bf16.Value), FP16SDWAInputMods,
+ Int32SDWAInputMods))));
}
// Returns the input arguments for VOP[12C] instructions for the given SrcVT.
More information about the llvm-commits mailing list