[PATCH] D144911: adding bf16 support to NVPTX
Artem Belevich via Phabricator via cfe-commits
cfe-commits at lists.llvm.org
Thu Jun 8 14:37:01 PDT 2023
tra added a comment.
Overall looks good, with a few minor nits and a couple of questions.
================
Comment at: llvm/include/llvm/IR/IntrinsicsNVVM.td:604
def int_nvvm_f # operation # variant :
ClangBuiltin<!strconcat("__nvvm_f", operation, variant)>,
DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_i16_ty, llvm_i16_ty],
----------------
tra wrote:
> tra wrote:
> > Availability of these new instructions is conditional on the specific CUDA version and the GPU variant we're compiling for.
> > Such builtins are normally implemented on the clang side as a `TARGET_BUILTIN()` with appropriate constraints.
> >
> > Without that, `ClangBuiltin` may automatically add enough glue to make them available in clang unconditionally, which would result in the compiler crashing if a user tries to use one of those builtins with the wrong GPU or CUDA version. We want to emit a diagnostic, not cause a compiler crash.
> >
> > Usually such related LLVM and clang changes should be part of the same patch.
> >
> > This applies to the new intrinsic variants added below, too.
> I do not think it's done.
>
> Can you check what happens if you try to call any of the bf16 builtins while compiling for sm_60? Ideally we should produce a sensible error that the builtin is not available.
>
> I suspect we will fail in LLVM when we fail to lower the intrinsic, or in NVPTX if we've managed to lower it to an instruction unsupported by sm_60.
OK. We'll leave conditional clang builtin handling to be fixed separately as it's not directly related to this patch.
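For reference, the clang-side declaration would be roughly a `TARGET_BUILTIN()` entry in `clang/include/clang/Basic/BuiltinsNVPTX.def` along these lines (just a sketch; the exact type string and feature macros would need to be verified against the file):

  // Hypothetical sketch only: gate the builtin on sm_80 + PTX 7.0 so clang can
  // diagnose uses on older targets instead of crashing in the backend.
  // Type string is a guess: packed result returned as unsigned int, two float operands.
  TARGET_BUILTIN(__nvvm_ff2bf16x2_rn, "Uiff", "", AND(SM_80, PTX70))

With a feature constraint like that in place, using the builtin while compiling for sm_60 should produce a "builtin requires target feature"-style diagnostic rather than an ISel failure.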
================
Comment at: llvm/include/llvm/IR/IntrinsicsNVVM.td:878
def int_nvvm_fma # variant : ClangBuiltin<!strconcat("__nvvm_fma", variant)>,
- DefaultAttrsIntrinsic<[llvm_i16_ty],
- [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty],
+ DefaultAttrsIntrinsic<[llvm_bfloat_ty],
+ [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty],
----------------
This changes the signatures of existing intrinsics and builtins. While the change is correct, we should at least check that the MLIR tests still pass.
================
Comment at: llvm/include/llvm/IR/IntrinsicsNVVM.td:1244-1251
def int_nvvm_ff2bf16x2_rn : ClangBuiltin<"__nvvm_ff2bf16x2_rn">,
Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
def int_nvvm_ff2bf16x2_rn_relu : ClangBuiltin<"__nvvm_ff2bf16x2_rn_relu">,
Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
def int_nvvm_ff2bf16x2_rz : ClangBuiltin<"__nvvm_ff2bf16x2_rz">,
Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
def int_nvvm_ff2bf16x2_rz_relu : ClangBuiltin<"__nvvm_ff2bf16x2_rz_relu">,
----------------
We've removed the patterns matching these intrinsics in lib/Target/NVPTX/NVPTXIntrinsics.td, so there's nothing to lower them to an instruction now. Was that intentional?
================
Comment at: llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp:64-69
+ case 9:
OS << "%h";
break;
case 8:
+ case 10:
OS << "%hh";
----------------
Looks like I forgot to remove those cases in my regclass patch. I'll fix that shortly.
================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:640
ISD::FROUNDEVEN, ISD::FTRUNC}) {
+ setOperationAction(Op, MVT::bf16, Legal);
setOperationAction(Op, MVT::f16, Legal);
----------------
Nit: the bf16 variants are sometimes added above the fp16 variants and sometimes after them. It would be nice to do this consistently. I guess we should just do a cleanup patch sorting these blocks in type order.
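E.g. always listing the actions in the same type order (purely illustrative):

  setOperationAction(Op, MVT::f16, Legal);
  setOperationAction(Op, MVT::bf16, Legal);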
================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:2514-2516
+ if ((Isv2f16Orv2bf16Type(VT.getSimpleVT())) &&
!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
VT, *Store->getMemOperand()))
----------------
Unnecessary `()` around `Isv2f16Orv2bf16Type(VT.getSimpleVT())`.
================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:2601
// store them with st.v4.b32.
- assert((EltVT == MVT::f16 || EltVT == MVT::bf16) &&
+ assert((Isf16Orbf16Type(EltVT.getSimpleVT())) &&
"Wrong type for the vector.");
----------------
Ditto.
================
Comment at: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td:1316-1326
defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Int16Regs, doF32FTZ>;
defm FMA16 : FMA_F16<"fma.rn.f16", f16, Int16Regs, True>;
defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>;
defm FMA16x2 : FMA_F16<"fma.rn.f16x2", v2f16, Int32Regs, True>;
+defm BFMA16_ftz : FMA_BF16<"fma.rn.ftz.bf16", bf16, Int16Regs, doF32FTZ>;
+defm BFMA16 : FMA_BF16<"fma.rn.bf16", bf16, Int16Regs, True>;
+defm BFMA16x2_ftz : FMA_BF16<"fma.rn.ftz.bf16x2", v2bf16, Int32Regs, doF32FTZ>;
----------------
Nit: align ':' across the block.
================
Comment at: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td:1892-1893
"mov.b16 \t$dst, $src;", []>;
+ def BFMOV16rr : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src),
+ "mov.b16 \t$dst, $src;", []>;
def FMOV32rr : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src),
----------------
We now have `IMOVB16rr` for this.
================
Comment at: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td:2811-2814
defm LDV_f16 : LD_VEC<Int16Regs>;
defm LDV_f16x2 : LD_VEC<Int32Regs>;
+ defm LDV_bf16 : LD_VEC<Int16Regs>;
+ defm LDV_bf16x2 : LD_VEC<Int32Regs>;
----------------
Those should no longer be necessary. `LDV_i16`/`LDV_i32` should do the job.
================
Comment at: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td:2908-2911
defm STV_f16 : ST_VEC<Int16Regs>;
defm STV_f16x2 : ST_VEC<Int32Regs>;
+ defm STV_bf16 : ST_VEC<Int16Regs>;
+ defm STV_bf16x2 : ST_VEC<Int32Regs>;
----------------
Ditto.
================
Comment at: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td:659-666
multiclass CVT_FROM_FLOAT_SM80<string FromName, RegisterClass RC> {
def _f32 :
NVPTXInst<(outs RC:$dst),
(ins Float32Regs:$src, CvtMode:$mode),
!strconcat("cvt${mode:base}${mode:relu}.",
FromName, ".f32 \t$dst, $src;"), []>,
Requires<[hasPTX70, hasSM80]>;
----------------
tra wrote:
> I think this multiclass can be deleted now.
The unused `CVT_FROM_FLOAT_SM80` still needs to be removed.
================
Comment at: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td:2094
- // f32 -> pred
+ //bf16 -> pred
+ def : Pat<(i1 (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b))),
----------------
tra wrote:
> nit. Needs space between the "//" and the comment itself. Same in the f32 block below.
Not done yet.
================
Comment at: llvm/lib/Target/NVPTX/NVPTXIntrinsics.td:1271-1287
-def : Pat<(int_nvvm_ff2f16x2_rn Float32Regs:$a, Float32Regs:$b),
- (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
-def : Pat<(int_nvvm_ff2f16x2_rn_relu Float32Regs:$a, Float32Regs:$b),
- (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
-def : Pat<(int_nvvm_ff2f16x2_rz Float32Regs:$a, Float32Regs:$b),
- (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ)>;
-def : Pat<(int_nvvm_ff2f16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
----------------
Were these patterns removed intentionally? We still have the intrinsics/builtins defined in llvm/include/llvm/IR/IntrinsicsNVVM.td, and we still need to lower them.
================
Comment at: llvm/lib/Target/NVPTX/NVPTXIntrinsics.td:2199-2206
defm INT_PTX_LDU_G_v2f16_ELE
: VLDU_G_ELE_V2<"v2.b16 \t{{$dst1, $dst2}}, [$src];", Int16Regs>;
defm INT_PTX_LDU_G_v2f16x2_ELE
: VLDU_G_ELE_V2<"v2.b32 \t{{$dst1, $dst2}}, [$src];", Int32Regs>;
+defm INT_PTX_LDU_G_v2bf16_ELE
+ : VLDU_G_ELE_V2<"v2.b16 \t{{$dst1, $dst2}}, [$src];", Int16Regs>;
+defm INT_PTX_LDU_G_v2bf16x2_ELE
----------------
Another place I missed. I think these are no longer necessary either.
================
Comment at: llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp:29-32
+static cl::opt<bool>
+ NoBF16Math("nvptx-no-bf16-math", cl::Hidden,
+ cl::desc("NVPTX Specific: Disable generation of bf16 math ops."),
+ cl::init(false));
----------------
I wonder if we need it at all for bf16.
The knob was needed for fp16 because some of the Pascal-generation GPUs had a very limited number of fp16 units (a 1:64 fp16:fp32 ratio on sm_61 vs 2:1 on sm_60, IIRC). While they could technically execute fp16 operations, it was much faster to promote them to fp32.
I do not know how well bf16 performs on consumer-grade GPU variants. We'll still need this knob only if bf16 is limited there in a way similar to fp16 on sm_61.
================
Comment at: llvm/test/CodeGen/NVPTX/bf16-instructions.ll:13
+define bfloat @test_fadd(bfloat %0, bfloat %1) {
+ %3 = fadd bfloat %0, %1
+ ret bfloat %3
----------------
We need more bf16 tests covering:
- bf16 constants.
- extracting high/low/both elements of v2bf16
- combining two bf16 scalars into v2bf16
- conversion bf16 <-> fp32
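Something along these lines, as a starting point (illustrative IR only; function names and constants are made up):

; bf16 constant (0xR3F80 is 1.0 as bfloat)
define bfloat @test_fadd_imm(bfloat %a) {
  %r = fadd bfloat %a, 0xR3F80
  ret bfloat %r
}

; extracting the low/high elements of v2bf16
define bfloat @test_extract_low(<2 x bfloat> %v) {
  %e = extractelement <2 x bfloat> %v, i32 0
  ret bfloat %e
}
define bfloat @test_extract_high(<2 x bfloat> %v) {
  %e = extractelement <2 x bfloat> %v, i32 1
  ret bfloat %e
}

; combining two bf16 scalars into v2bf16
define <2 x bfloat> @test_build_vector(bfloat %a, bfloat %b) {
  %v0 = insertelement <2 x bfloat> undef, bfloat %a, i32 0
  %v1 = insertelement <2 x bfloat> %v0, bfloat %b, i32 1
  ret <2 x bfloat> %v1
}

; bf16 <-> fp32 conversions
define float @test_fpext_float(bfloat %a) {
  %r = fpext bfloat %a to float
  ret float %r
}
define bfloat @test_fptrunc_float(float %a) {
  %r = fptrunc float %a to bfloat
  ret bfloat %r
}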
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D144911/new/
https://reviews.llvm.org/D144911