[llvm] 3699811 - [AMDGPU] Handle bf16 operands the same way as f16. NFC. (#77826)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 11 21:08:23 PST 2024


Author: Stanislav Mekhanoshin
Date: 2024-01-11T21:08:19-08:00
New Revision: 369981181ffa75a8500416982417662f1fa04704

URL: https://github.com/llvm/llvm-project/commit/369981181ffa75a8500416982417662f1fa04704
DIFF: https://github.com/llvm/llvm-project/commit/369981181ffa75a8500416982417662f1fa04704.diff

LOG: [AMDGPU] Handle bf16 operands the same way as f16. NFC. (#77826)

This is an infrastructure change that will allow bf16 operands to be
used in instruction definitions.

Added: 
    

Modified: 
    llvm/lib/Target/AMDGPU/SIInstrInfo.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
index 04c92155f5aada..441d72cc173f5b 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
@@ -284,12 +284,17 @@ def SIfptrunc_round_downward : SDNode<"AMDGPUISD::FPTRUNC_ROUND_DOWNWARD",
 // Returns 1 if the source arguments have modifiers, 0 if they do not.
 class isFloatType<ValueType SrcVT> {
   bit ret = !or(!eq(SrcVT.Value, f16.Value),
+                !eq(SrcVT.Value, bf16.Value),
                 !eq(SrcVT.Value, f32.Value),
                 !eq(SrcVT.Value, f64.Value),
                 !eq(SrcVT.Value, v2f16.Value),
+                !eq(SrcVT.Value, v2bf16.Value),
                 !eq(SrcVT.Value, v4f16.Value),
+                !eq(SrcVT.Value, v4bf16.Value),
                 !eq(SrcVT.Value, v8f16.Value),
+                !eq(SrcVT.Value, v8bf16.Value),
                 !eq(SrcVT.Value, v16f16.Value),
+                !eq(SrcVT.Value, v16bf16.Value),
                 !eq(SrcVT.Value, v2f32.Value),
                 !eq(SrcVT.Value, v4f32.Value),
                 !eq(SrcVT.Value, v8f32.Value),
@@ -314,7 +319,9 @@ class isIntType<ValueType SrcVT> {
 class isPackedType<ValueType SrcVT> {
   bit ret = !or(!eq(SrcVT.Value, v2i16.Value),
                 !eq(SrcVT.Value, v2f16.Value),
+                !eq(SrcVT.Value, v2bf16.Value),
                 !eq(SrcVT.Value, v4f16.Value),
+                !eq(SrcVT.Value, v4bf16.Value),
                 !eq(SrcVT.Value, v2i32.Value),
                 !eq(SrcVT.Value, v2f32.Value),
                 !eq(SrcVT.Value, v4i32.Value),
@@ -1495,14 +1502,14 @@ class getVOPSrc0ForVT<ValueType VT, bit IsTrue16, bit IsFake16 = 1> {
     !if(isFP,
       !if(!eq(VT.Size, 64),
          VSrc_f64,
-         !if(!eq(VT.Value, f16.Value),
+         !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
             !if(IsTrue16,
               !if(IsFake16, VSrcFake16_f16_Lo128, VSrcT_f16_Lo128),
               VSrc_f16
             ),
-            !if(!eq(VT.Value, v2f16.Value),
+            !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
                VSrc_v2f16,
-               !if(!eq(VT.Value, v4f16.Value),
+               !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
                  AVSrc_64,
                  VSrc_f32
                )
@@ -1576,11 +1583,11 @@ class getVOP3SrcForVT<ValueType VT, bit IsTrue16 = 0> {
         !if(!eq(VT.Value, i1.Value),
            SSrc_i1,
            !if(isFP,
-              !if(!eq(VT.Value, f16.Value),
+              !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
                  !if(IsTrue16, VSrcT_f16, VSrc_f16),
-                 !if(!eq(VT.Value, v2f16.Value),
+                 !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
                     VSrc_v2f16,
-                    !if(!eq(VT.Value, v4f16.Value),
+                    !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
                       AVSrc_64,
                       VSrc_f32
                     )
@@ -1605,8 +1612,8 @@ class getVOP3DPPSrcForVT<ValueType VT> {
   RegisterOperand ret =
       !if (!eq(VT.Value, i1.Value), SSrc_i1,
            !if (isFP,
-                !if (!eq(VT.Value, f16.Value), VCSrc_f16,
-                     !if (!eq(VT.Value, v2f16.Value), VCSrc_v2f16, VCSrc_f32)),
+                !if (!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)), VCSrc_f16,
+                     !if (!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)), VCSrc_v2f16, VCSrc_f32)),
                 !if (!eq(VT.Value, i16.Value), VCSrc_b16,
                      !if (!eq(VT.Value, v2i16.Value), VCSrc_v2b16,
                           VCSrc_b32))));
@@ -1615,22 +1622,27 @@ class getVOP3DPPSrcForVT<ValueType VT> {
 // Float or packed int
 class isModifierType<ValueType SrcVT> {
   bit ret = !or(!eq(SrcVT.Value, f16.Value),
+                !eq(SrcVT.Value, bf16.Value),
                 !eq(SrcVT.Value, f32.Value),
                 !eq(SrcVT.Value, f64.Value),
                 !eq(SrcVT.Value, v2f16.Value),
                 !eq(SrcVT.Value, v2i16.Value),
+                !eq(SrcVT.Value, v2bf16.Value),
                 !eq(SrcVT.Value, v2f32.Value),
                 !eq(SrcVT.Value, v2i32.Value),
                 !eq(SrcVT.Value, v4f16.Value),
                 !eq(SrcVT.Value, v4i16.Value),
+                !eq(SrcVT.Value, v4bf16.Value),
                 !eq(SrcVT.Value, v4f32.Value),
                 !eq(SrcVT.Value, v4i32.Value),
                 !eq(SrcVT.Value, v8f16.Value),
                 !eq(SrcVT.Value, v8i16.Value),
+                !eq(SrcVT.Value, v8bf16.Value),
                 !eq(SrcVT.Value, v8f32.Value),
                 !eq(SrcVT.Value, v8i32.Value),
                 !eq(SrcVT.Value, v16f16.Value),
-                !eq(SrcVT.Value, v16i16.Value));
+                !eq(SrcVT.Value, v16i16.Value),
+                !eq(SrcVT.Value, v16bf16.Value));
 }
 
 // Return type of input modifiers operand for specified input operand
@@ -1646,7 +1658,8 @@ class getSrcMod <ValueType VT, bit IsTrue16 = 0> {
 }
 
 class getOpSelMod <ValueType VT> {
-  Operand ret = !if(!eq(VT.Value, f16.Value), FP16InputMods, IntOpSelMods);
+  Operand ret = !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
+                    FP16InputMods, IntOpSelMods);
 }
 
 // Return type of input modifiers operand specified input operand for DPP
@@ -1659,8 +1672,8 @@ class getSrcModDPP_t16 <ValueType VT> {
   bit isFP = isFloatType<VT>.ret;
   Operand ret =
       !if (isFP,
-           !if (!eq(VT.Value, f16.Value), FPT16VRegInputMods,
-                FPVRegInputMods),
+           !if (!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
+                FPT16VRegInputMods, FPVRegInputMods),
            !if (!eq(VT.Value, i16.Value), IntT16VRegInputMods,
                 IntVRegInputMods));
 }
@@ -1671,8 +1684,8 @@ class getSrcModVOP3DPP <ValueType VT> {
   bit isPacked = isPackedType<VT>.ret;
   Operand ret =
       !if (isFP,
-           !if (!eq(VT.Value, f16.Value), FP16VCSrcInputMods,
-                FP32VCSrcInputMods),
+           !if (!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
+                FP16VCSrcInputMods, FP32VCSrcInputMods),
            Int32VCSrcInputMods);
 }
 
@@ -1681,7 +1694,8 @@ class getSrcModSDWA <ValueType VT> {
   Operand ret = !if(!eq(VT.Value, f16.Value), FP16SDWAInputMods,
                 !if(!eq(VT.Value, f32.Value), FP32SDWAInputMods,
                 !if(!eq(VT.Value, i16.Value), Int16SDWAInputMods,
-                Int32SDWAInputMods)));
+                !if(!eq(VT.Value, bf16.Value), FP16SDWAInputMods,
+                Int32SDWAInputMods))));
 }
 
 // Returns the input arguments for VOP[12C] instructions for the given SrcVT.


        


More information about the llvm-commits mailing list