[PATCH] D144911: adding bf16 support to NVPTX

Artem Belevich via Phabricator via cfe-commits cfe-commits at lists.llvm.org
Fri May 26 16:15:17 PDT 2023


tra added inline comments.


================
Comment at: llvm/include/llvm/IR/IntrinsicsNVVM.td:604
       def int_nvvm_f # operation # variant :
         ClangBuiltin<!strconcat("__nvvm_f", operation, variant)>,
         DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_i16_ty, llvm_i16_ty],
----------------
tra wrote:
> Availability of these new instructions is conditional on the specific CUDA version and the GPU variant we're compiling for.
> Such builtins are normally implemented on the clang side as a `TARGET_BUILTIN()` with appropriate constraints.
> 
> Without that, `ClangBuiltin` may automatically add enough glue to make them available in clang unconditionally, which would result in the compiler crashing if a user tries to use one of those builtins with the wrong GPU or CUDA version. We want to emit a diagnostic, not crash the compiler.
> 
> Usually such related LLVM and clang changes should be part of the same patch.
> 
> This applies to the new intrinsic variants added below, too.
I do not think this is done yet.

Can you check what happens if you try to call any of the bf16 builtins while compiling for sm_60? Ideally we should produce a sensible error saying that the builtin is not available.

I suspect we will fail either in LLVM, when we fail to lower the intrinsic, or in NVPTX, if we've managed to lower it to an instruction unsupported by sm_60.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp:1297-1304
     if (EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) {
       assert(NumElts % 2 == 0 && "Vector must have even number of elements");
       EltVT = MVT::v2f16;
       NumElts /= 2;
+    } else if (EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) {
+      assert(NumElts % 2 == 0 && "Vector must have even number of elements");
+      EltVT = MVT::v2bf16;
----------------
These could be collapsed into:
```
// Treat each pair of 16-bit elements as one v2f16/v2bf16 element.
if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
    (EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16)) {
  assert(NumElts % 2 == 0 && "Vector must have even number of elements");
  EltVT = N->getValueType(0);
  NumElts /= 2;
}
```


================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:147-153
+  switch (VT.SimpleTy) {
+  default:
+    return false;
+  case MVT::v2f16:
+  case MVT::v2bf16:
+    return true;
+  }
----------------
It can be simplified to just `return (VT.SimpleTy == MVT::v2f16 || VT.SimpleTy == MVT::v2bf16);`



================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:156
+
+static bool Isf16Orbf16Type(MVT VT) {
+  switch (VT.SimpleTy) {
----------------
ditto.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:623
 
+  for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
+    setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
----------------
Fold it into the loop above.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:699
   }
+  for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
+    setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Promote), Promote);
----------------
Fold into the loop processing `{ISD::FMINNUM, ISD::FMAXNUM}` above.

Also, do we want/need to add bf16 handling for `{ISD::FMINIMUM, ISD::FMAXIMUM}` too?

LLVM's choice of the constant names `FMINIMUM` vs `FMINNUM` is rather unfortunate: it's so easy to misread one for the other.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:700-703
+    setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Promote), Promote);
+    setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand);
+    setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Expand), Expand);
+    setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand);
----------------
I'm not sure what's going on here. Should it be Promote for bf16 and Expand for v2bf16? Why do we have two other entries, one of them trying to Expand bf16?


================
Comment at: llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td:65-66
+def Float16x2Regs : NVPTXRegClass<[v2f16], 32, (add (sequence "HH%u", 0, 4))>;
+def BFloat16Regs : NVPTXRegClass<[bf16], 16, (add (sequence "H%u", 0, 4))>;
+def BFloat16x2Regs : NVPTXRegClass<[v2bf16], 32, (add (sequence "HH%u", 0, 4))>;
 def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>;
----------------
I suspect this may be a problem.

What PTX do we end up generating if we have a function that needs to use both f16 and bf16 registers? I suspect we may end up with defining conflicting sets of registers.

I still do not think that we need a separate register class for bf16. Both bf16 and fp16 should be using generic opaque 16/32-bit register types (or, even better, the generic Int16/Int32 registers).

RegClass accepts multiple type values, so it may be as simple as using `def Int16Regs : NVPTXRegClass<[i16,f16,bf16], 16, (add (sequence "RS%u", 0, 4))>;` and adjusting existing use cases.

That should probably be done as a separate patch as there would be a lot of churn unrelated to bf16 support. I'll check if it's doable next week.


Repository:
  rG LLVM Github Monorepo


https://reviews.llvm.org/D144911