[PATCH] D144911: adding bf16 support to NVPTX

Artem Belevich via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 3 16:38:10 PDT 2023


tra added a comment.

Overall the patch looks reasonable.

The only thing I'm not quite happy about is the introduction of yet another set of register class aliases mapping to the same actual .b16/.b32 register types.

Another potential issue is that the patch may make the new builtins available unconditionally in clang. If you can test whether clang accepts the new builtins even when compiling for an older GPU, that would tell us whether this is a real problem. If clang can't see the builtins introduced by the patch, then we're OK. If it does accept them, we'll end up generating IR that LLVM can't compile. In that case you will need to include clang-side plumbing for those builtins using `TARGET_BUILTIN()` with appropriate constraints.



================
Comment at: llvm/include/llvm/IR/IntrinsicsNVVM.td:604
       def int_nvvm_f # operation # variant :
         ClangBuiltin<!strconcat("__nvvm_f", operation, variant)>,
         DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_i16_ty, llvm_i16_ty],
----------------
Availability of these new instructions is conditional on the specific CUDA version and the GPU variant we're compiling for.
Such builtins are normally implemented on the clang side as a `TARGET_BUILTIN()` with appropriate constraints.

Without that, `ClangBuiltin` may automatically add enough glue to make them available in clang unconditionally, which would result in the compiler crashing if a user tries to use one of those builtins with the wrong GPU or CUDA version. We want to emit a diagnostic, not crash the compiler.

Usually such related LLVM and clang changes should be part of the same patch.

This applies to the new intrinsic variants added below, too.
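To make the "diagnostic, not crash" distinction concrete, here is a minimal, self-contained C++ sketch of the kind of check a constrained builtin implies. `TargetInfo`, `BuiltinConstraint`, and `checkBuiltinAvailability` are hypothetical names, not clang's implementation; the sm_80 / PTX 7.0 thresholds are taken from the `[hasPTX70, hasSM80]` predicates in this patch, and `__nvvm_fmin_bf16` is used only as an illustrative name of the `__nvvm_f` # operation # variant form.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical model of the compilation target (not a clang type).
struct TargetInfo {
  unsigned SM;  // GPU variant, e.g. 80 for sm_80
  unsigned PTX; // PTX ISA version, e.g. 70 for PTX 7.0
};

// Hypothetical availability constraint attached to a builtin. The bf16
// instructions in this patch are predicated on [hasPTX70, hasSM80].
struct BuiltinConstraint {
  unsigned MinSM;
  unsigned MinPTX;
};

// Returns a diagnostic message if the builtin cannot be used on the target,
// or std::nullopt if it can. The front end reports the problem up front
// instead of emitting IR that the backend cannot select.
std::optional<std::string> checkBuiltinAvailability(const std::string &Name,
                                                    const BuiltinConstraint &C,
                                                    const TargetInfo &T) {
  assert(!Name.empty() && "builtin name must not be empty");
  if (T.SM < C.MinSM)
    return Name + " requires sm_" + std::to_string(C.MinSM) + " or newer";
  if (T.PTX < C.MinPTX)
    return Name + " requires PTX " + std::to_string(C.MinPTX / 10) + "." +
           std::to_string(C.MinPTX % 10) + " or newer";
  return std::nullopt;
}
```

In clang itself this gating is expressed declaratively through the constraints on the `TARGET_BUILTIN()` entry rather than through hand-written checks like this one; the sketch only shows the intended behavior.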


================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:414-415
   addRegisterClass(MVT::v2f16, &NVPTX::Float16x2RegsRegClass);
-  addRegisterClass(MVT::bf16, &NVPTX::Float16RegsRegClass);
-  addRegisterClass(MVT::v2bf16, &NVPTX::Float16x2RegsRegClass);
+  addRegisterClass(MVT::bf16, &NVPTX::BFloat16RegsRegClass);
+  addRegisterClass(MVT::v2bf16, &NVPTX::BFloat16x2RegsRegClass);
 
----------------
Is there a particular reason we need to create another register class for `.b16` and `.b32` registers?
I think ideally register classes should represent the actual register types available in PTX. While f16, bf16, etc. are operation types, they are not *register* types.

One of the side effects of adding multiple aliases for the actual register types is that LLVM ends up generating unnecessary moves just to convert between different register classes.

We've had that pesky behavior with fp16, and I was attempting to avoid creating new redundant register classes for bf16 when I added storage-only support for bf16. `Float16*RegsRegClass` should probably be renamed to reflect that it represents an opaque 16- or 32-bit register, but it did represent those registers just fine.
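A minimal sketch of the point above, assuming a simplified, hypothetical `ValueType` enum in place of LLVM's `MVT` and hypothetical `B16`/`B32` classes: register classes are keyed on the PTX register width, and f16/bf16 (scalar or two-lane) are operation types that share those classes, so copying between them never needs a cross-class move. `getRegClassFor` and `needsCrossClassMove` are illustrative names, not the NVPTX backend's API.

```cpp
#include <cassert>
#include <cstdlib>

// Simplified stand-ins for the value types discussed in the review
// (hypothetical; LLVM spells these MVT::f16, MVT::bf16, MVT::v2f16, MVT::v2bf16).
enum class ValueType { f16, bf16, v2f16, v2bf16 };

// Register classes named after the PTX register types they model.
enum class RegClass { B16, B32 };

// Maps an operation type to the register class that holds it. f16 and bf16
// differ only in how instructions interpret the bits, so they share a class.
RegClass getRegClassFor(ValueType VT) {
  switch (VT) {
  case ValueType::f16:
  case ValueType::bf16:
    return RegClass::B16; // opaque 16-bit .b16 register
  case ValueType::v2f16:
  case ValueType::v2bf16:
    return RegClass::B32; // two 16-bit lanes packed in a .b32 register
  }
  assert(false && "unhandled value type");
  std::abort();
}

// A copy needs a cross-class move only when source and destination live in
// different register classes. With width-based classes, f16 <-> bf16 and
// v2f16 <-> v2bf16 copies never do.
bool needsCrossClassMove(ValueType Src, ValueType Dst) {
  return getRegClassFor(Src) != getRegClassFor(Dst);
}
```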



================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:633
     setOperationAction(Op, MVT::f64, Legal);
     setOperationAction(Op, MVT::v2f16, Expand);
   }
----------------
Do you want to add bf16x2 here?


================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:680-688
+  for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
+    setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Promote), Promote);
+    setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand);
+  }
+  for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
+    setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Expand), Expand);
+    setOperationAction(Op, MVT::f32, GetMinMaxAction(Expand));
----------------
These loops should be coalesced with their FP siblings above.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:686
+    setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Expand), Expand);
+    setOperationAction(Op, MVT::f32, GetMinMaxAction(Expand));
+    setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand);
----------------
I think the f32 variant may have been included here by mistake.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:1304
     return TypeSplitVector;
-  if (VT == MVT::v2f16)
+  if (VT == MVT::v2f16 || VT == MVT::v2bf16)
     return TypeLegal;
----------------
I think we could use a couple of simple helper functions to predicate things common to operations on .b16 and .b32 types like f16 and bf16. Otherwise, enumerating all the involved types gets repetitive and error-prone.
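For illustration, such helpers might look like the sketch below. `SimpleVT`, `is16BitFPType`, `isPacked16BitFPType`, and `preferredVectorAction` are hypothetical: a simplified enum stands in for `MVT`, and the last function only mocks the single check shown in the hunk above, not the full legalization hook.

```cpp
// Simplified stand-in for llvm::MVT (hypothetical, for illustration only).
enum class SimpleVT { i16, f16, bf16, f32, v2f16, v2bf16, v4f16 };

// Scalar 16-bit floating-point types, held in .b16 registers.
constexpr bool is16BitFPType(SimpleVT VT) {
  return VT == SimpleVT::f16 || VT == SimpleVT::bf16;
}

// Two-lane 16-bit floating-point vectors, held in a single .b32 register.
constexpr bool isPacked16BitFPType(SimpleVT VT) {
  return VT == SimpleVT::v2f16 || VT == SimpleVT::v2bf16;
}

enum class VectorAction { TypeLegal, Other };

// Mock of the check in the hunk above; the real hook handles more cases.
constexpr VectorAction preferredVectorAction(SimpleVT VT) {
  // Before: if (VT == MVT::v2f16 || VT == MVT::v2bf16)
  if (isPacked16BitFPType(VT))
    return VectorAction::TypeLegal;
  return VectorAction::Other; // placeholder for the remaining cases
}
```

Adding another 16-bit type later would then mean touching one predicate instead of every comparison site.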


================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:2103
+             ? DAG.getNode(ISD::BITCAST, SDLoc(Op), MVT::v2bf16, Const)
+             : DAG.getNode(ISD::BITCAST, SDLoc(Op), MVT::v2f16, Const);
 }
----------------
I think it could be simplified to just

`return DAG.getNode(ISD::BITCAST, SDLoc(Op), Op->getValueType(0), Const);`


================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:2611
       NumElts /= 2;
-      for (unsigned i = 0; i < NumElts; ++i) {
-        SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val,
-                                 DAG.getIntPtrConstant(i * 2, DL));
-        SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val,
-                                 DAG.getIntPtrConstant(i * 2 + 1, DL));
-        SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2f16, E0, E1);
-        Ops.push_back(V2);
+      if (EltVT == MVT::f16) {
+        for (unsigned i = 0; i < NumElts; ++i) {
----------------
I'd suggest choosing the scalar and v2 types up front and then using the same loop to generate code.
Actually, `EltVT` already has the right scalar type, so the v2 type can be derived with `getVectorVT(EltVT, 2)`. I don't think we need to copy/paste the whole loop here.
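Outside the SelectionDAG API, the shape of the suggested change can be sketched as a template: the element type is a parameter, and one loop packs elements `(i*2, i*2+1)` into two-lane values whatever that type is. `Pair` and `packIntoPairs` are hypothetical; in the real code the pair type would come from `getVectorVT(EltVT, 2)` and each pair would be built from `EXTRACT_VECTOR_ELT`/`BUILD_VECTOR` nodes as in the removed lines.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical two-lane vector, standing in for a v2f16/v2bf16 value.
template <typename Elt> using Pair = std::array<Elt, 2>;

// Packs elements (i*2, i*2+1) into pairs. Because the element type is a
// parameter, f16 and bf16 (or any other type) share the same loop body.
template <typename Elt>
std::vector<Pair<Elt>> packIntoPairs(const std::vector<Elt> &Elts) {
  assert(Elts.size() % 2 == 0 && "expected an even number of elements");
  std::size_t NumPairs = Elts.size() / 2;
  std::vector<Pair<Elt>> Ops;
  Ops.reserve(NumPairs);
  for (std::size_t i = 0; i < NumPairs; ++i)
    Ops.push_back(Pair<Elt>{Elts[i * 2], Elts[i * 2 + 1]});
  return Ops;
}
```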



================
Comment at: llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp:54-63
   } else if (DestRC == &NVPTX::Float16RegsRegClass) {
     Op = (SrcRC == &NVPTX::Float16RegsRegClass ? NVPTX::FMOV16rr
                                                : NVPTX::BITCONVERT_16_I2F);
   } else if (DestRC == &NVPTX::Float16x2RegsRegClass) {
     Op = NVPTX::IMOV32rr;
+  } else if (DestRC == &NVPTX::BFloat16RegsRegClass) {
+    Op = (SrcRC == &NVPTX::BFloat16RegsRegClass ? NVPTX::BFMOV16rr
----------------
Those are the moves we really do not need.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td:659-666
 multiclass CVT_FROM_FLOAT_SM80<string FromName, RegisterClass RC> {
     def _f32 :
       NVPTXInst<(outs RC:$dst),
                 (ins Float32Regs:$src, CvtMode:$mode),
                 !strconcat("cvt${mode:base}${mode:relu}.",
                 FromName, ".f32 \t$dst, $src;"), []>,
                 Requires<[hasPTX70, hasSM80]>;
----------------
I think this multiclass can be deleted now.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td:756-761
+def SELP_bf16x2rr :
+    NVPTXInst<(outs BFloat16x2Regs:$dst),
+              (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b, Int1Regs:$p),
+              "selp.b32 \t$dst, $a, $b, $p;",
+              [(set BFloat16x2Regs:$dst,
+                    (select Int1Regs:$p, (v2bf16 BFloat16x2Regs:$a), (v2bf16 BFloat16x2Regs:$b)))]>;
----------------
Another addition which would not be necessary if we were operating on generic .b16/.b32 registers.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td:2094
 
-  // f32 -> pred
+  //bf16 -> pred
+  def : Pat<(i1 (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b))),
----------------
Nit: needs a space between the "//" and the comment itself. Same in the f32 block below.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td:2159
+  
+    // bf16 -> i32
+  def : Pat<(i32 (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b))),
----------------
Nit: comment alignment is off, probably due to the use of <TAB>. Please make sure your changes only use spaces.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td:2710
   defm LD_f16 : LD<Float16Regs>;
+  defm LD_bf16 : LD<BFloat16Regs>;
   defm LD_f16x2 : LD<Float16x2Regs>;
----------------
Another instance of redundancy required by redundant register classes.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td:3356-3391
+  def BF16x2toBF16_0 : NVPTXInst<(outs BFloat16Regs:$dst),
+                               (ins BFloat16x2Regs:$src),
+                               "{{ .reg .b16 \t%tmp_hi;\n\t"
+                               "  mov.b32 \t{$dst, %tmp_hi}, $src; }}",
+                               [(set BFloat16Regs:$dst,
+                                 (extractelt (v2bf16 BFloat16x2Regs:$src), 0))]>;
+  def BF16x2toBF16_1 : NVPTXInst<(outs BFloat16Regs:$dst),
----------------
Ditto.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXIntrinsics.td:1001-1002
 
-    FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, Int16Regs, [hasPTX70, hasSM80]>,
-    FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, Int16Regs,
-      [hasPTX70, hasSM80]>,
+    // FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, BFloat16Regs,
+    //   [hasPTX70, hasSM80]>,
 
----------------
Should it be removed? Uncommented?


================
Comment at: llvm/lib/Target/NVPTX/NVPTXIntrinsics.td:1257-1274
+// def : Pat<(int_nvvm_ff2f16x2_rn Float32Regs:$a, Float32Regs:$b),
+//           (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
+// def : Pat<(int_nvvm_ff2f16x2_rn_relu Float32Regs:$a, Float32Regs:$b),
+//           (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
+// def : Pat<(int_nvvm_ff2f16x2_rz Float32Regs:$a, Float32Regs:$b),
+//           (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ)>;
+// def : Pat<(int_nvvm_ff2f16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
----------------
Remove? Uncomment?


================
Comment at: llvm/lib/Target/NVPTX/NVPTXIntrinsics.td:1390-1394
+// def : Pat<(int_nvvm_bf2h_rn_ftz Float32Regs:$a),
+//           (BITCONVERT_16_BF2I (CVT_bf16_f32 Float32Regs:$a, CvtRN_FTZ))>;
+// def : Pat<(int_nvvm_f2h_rn BFloat16Regs:$a),
+//           (BITCONVERT_16_BF2I (CVT_bf16_f32 BFloat16Regs:$a, CvtRN))>;
+
----------------
Ditto.


Repository:
  rG LLVM Github Monorepo


https://reviews.llvm.org/D144911


