[PATCH] D144911: adding bf16 support to NVPTX
Artem Belevich via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Fri May 26 16:15:18 PDT 2023
tra added inline comments.
================
Comment at: llvm/include/llvm/IR/IntrinsicsNVVM.td:604
def int_nvvm_f # operation # variant :
ClangBuiltin<!strconcat("__nvvm_f", operation, variant)>,
DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_i16_ty, llvm_i16_ty],
----------------
tra wrote:
> Availability of these new instructions is conditional on the specific CUDA version and the GPU variant we're compiling for.
> Such builtins are normally implemented on the clang side as a `TARGET_BUILTIN()` with appropriate constraints.
>
> Without that, `ClangBuiltin` may automatically add enough glue to make them available in clang unconditionally, which would result in the compiler crashing if a user tries to use one of those builtins with a wrong GPU or CUDA version. We want to emit a diagnostic, not cause a compiler crash.
>
> Usually such related LLVM and clang changes should be part of the same patch.
>
> This applies to the new intrinsic variants added below, too.
I do not think this is done.
Can you check what happens if you try to call any of the bf16 builtins while compiling for sm_60? Ideally we should produce a sensible error that the builtin is not available.
I suspect we will fail in LLVM when we fail to lower the intrinsic, or in NVPTX if we've managed to lower it to an instruction unsupported by sm_60.
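For reference, a clang-side declaration along these lines is what I had in mind (a sketch only -- the exact builtin name, type string, and feature guard are assumptions and need to match the actual SM/PTX requirements of the instruction):
```
// Hypothetical entry for clang/include/clang/Basic/BuiltinsNVPTX.def.
// "UsUsUs" encodes unsigned-short return/args (bf16 carried as i16),
// and the feature guard makes clang reject the builtin with a proper
// diagnostic unless targeting sm_80+ with PTX 7.0+.
TARGET_BUILTIN(__nvvm_fmin_bf16, "UsUsUs", "", AND(SM_80, PTX70))
```
With that in place, calling the builtin under `-march=sm_60` fails in Sema with "builtin requires target feature" instead of crashing in the backend.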
================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp:1297-1304
if (EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) {
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
EltVT = MVT::v2f16;
NumElts /= 2;
+ } else if (EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) {
+ assert(NumElts % 2 == 0 && "Vector must have even number of elements");
+ EltVT = MVT::v2bf16;
----------------
These could be collapsed into
```
if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
(EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16)) {
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
EltVT = N->getValueType(0);
NumElts /= 2;
}
```
================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:147-153
+ switch (VT.SimpleTy) {
+ default:
+ return false;
+ case MVT::v2f16:
+ case MVT::v2bf16:
+ return true;
+ }
----------------
It can be simplified to just `return (VT.SimpleTy == MVT::v2f16 || VT.SimpleTy == MVT::v2bf16);`
================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:156
+
+static bool Isf16Orbf16Type(MVT VT) {
+ switch (VT.SimpleTy) {
----------------
ditto.
================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:623
+ for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
+ setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
----------------
Fold it into the loop above.
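I.e., something along these lines (a sketch only; the f16 action choices are copied from the surrounding code and the bf16 ones may need adjusting for the targets where bf16 is not natively supported):
```
for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
  setFP16OperationAction(Op, MVT::f16, Legal, Promote);
  setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
  // Handle bf16 in the same loop instead of a second one.
  setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
  setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
}
```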
================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:699
}
+ for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
+ setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Promote), Promote);
----------------
Fold into the loop processing `{ISD::FMINNUM, ISD::FMAXNUM}` above.
Also, do we want/need to add bf16 handling for `{ISD::FMINIMUM, ISD::FMAXIMUM}` too?
LLVM's choice of constants `FMINIMUM` vs `FMINNUM` is rather unfortunate -- it's so easy to misread one for the other.
================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:700-703
+ setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Promote), Promote);
+ setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand);
+ setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Expand), Expand);
+ setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand);
----------------
I'm not sure what's going on here. Should it be Promote for bf16 and Expand for v2bf16? Why do we have two other entries, one of them trying to Expand bf16?
================
Comment at: llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td:65-66
+def Float16x2Regs : NVPTXRegClass<[v2f16], 32, (add (sequence "HH%u", 0, 4))>;
+def BFloat16Regs : NVPTXRegClass<[bf16], 16, (add (sequence "H%u", 0, 4))>;
+def BFloat16x2Regs : NVPTXRegClass<[v2bf16], 32, (add (sequence "HH%u", 0, 4))>;
def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>;
----------------
I suspect this may be a problem.
What PTX do we end up generating if we have a function that needs to use both f16 and bf16 registers? I suspect we may end up with defining conflicting sets of registers.
I still do not think that we need a separate register class for bf16; both bf16 and fp16 should be using generic opaque 16/32-bit register types (or, even better, generic Int16/Int32 registers).
RegClass accepts multiple type values, so it may be as simple as using `def Int16Regs : NVPTXRegClass<[i16,f16,bf16], 16, (add (sequence "RS%u", 0, 4))>;` and adjusting existing use cases.
That should probably be done as a separate patch as there would be a lot of churn unrelated to bf16 support. I'll check if it's doable next week.
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D144911/new/
https://reviews.llvm.org/D144911