[PATCH] D144911: adding bf16 support to NVPTX

Artem Belevich via Phabricator via cfe-commits cfe-commits at lists.llvm.org
Thu Jun 8 14:37:01 PDT 2023


tra added a comment.

Overall looks good, with a few minor nits and a couple of questions.



================
Comment at: llvm/include/llvm/IR/IntrinsicsNVVM.td:604
       def int_nvvm_f # operation # variant :
         ClangBuiltin<!strconcat("__nvvm_f", operation, variant)>,
         DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_i16_ty, llvm_i16_ty],
----------------
tra wrote:
> tra wrote:
> > Availability of these new instructions is conditional on the specific CUDA version and the GPU variant we're compiling for.
> > Such builtins are normally implemented on the clang side as a `TARGET_BUILTIN()` with appropriate constraints.
> > 
> > Without that, `ClangBuiltin` may automatically add enough glue to make them available in clang unconditionally, which would result in the compiler crashing if a user tries to use one of those builtins with the wrong GPU or CUDA version. We want to emit a diagnostic, not cause a compiler crash.
> > 
> > Usually such related LLVM and clang changes should be part of the same patch.
> > 
> > This applies to the new intrinsic variants added below, too.
> I do not think it's done.
> 
> Can you check what happens if you try to call any of bf16 builtins while compiling for sm_60? Ideally we should produce a sensible error that the builtin is not available.
> 
> I suspect we will fail either in LLVM, when we fail to lower the intrinsic, or in NVPTX, if we've managed to lower it to an instruction unsupported by sm_60.
OK. We'll leave conditional clang builtin handling to be fixed separately as it's not directly related to this patch.
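For the follow-up, a rough host-side model of what `TARGET_BUILTIN()`-style gating buys us: check the builtin against the target before codegen and produce a diagnostic instead of letting the backend crash later. This is a hypothetical sketch, not clang's implementation. `TargetInfo`, `BuiltinRequirement`, `checkBuiltin`, and the table entries are all made-up names, and the sm_80 / PTX 7.0 thresholds are assumed, borrowed from the `Requires<[hasPTX70, hasSM80]>` in the `CVT_FROM_FLOAT_SM80` multiclass quoted further down.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <string_view>
#include <unordered_map>

// Hypothetical description of the target we are compiling for.
struct TargetInfo {
  unsigned SM;  // e.g. 60 for sm_60, 80 for sm_80
  unsigned PTX; // e.g. 70 for PTX ISA 7.0
};

// Hypothetical minimum requirements for a gated builtin.
struct BuiltinRequirement {
  unsigned MinSM;
  unsigned MinPTX;
};

// Assumed requirement table; the real constraints live in clang's
// TARGET_BUILTIN() definitions.
const std::unordered_map<std::string_view, BuiltinRequirement> &requirements() {
  static const std::unordered_map<std::string_view, BuiltinRequirement> Table = {
      {"__nvvm_fma_rn_bf16", {80, 70}},
      {"__nvvm_ff2bf16x2_rn", {80, 70}},
  };
  return Table;
}

// Returns a diagnostic if the builtin is unavailable on this target, or
// std::nullopt if it is available (or not gated at all).
std::optional<std::string> checkBuiltin(std::string_view Name,
                                        const TargetInfo &T) {
  assert(!Name.empty() && "builtin name must not be empty");
  auto It = requirements().find(Name);
  if (It == requirements().end())
    return std::nullopt;
  const BuiltinRequirement &R = It->second;
  if (T.SM < R.MinSM || T.PTX < R.MinPTX)
    return "builtin '" + std::string(Name) + "' requires sm_" +
           std::to_string(R.MinSM) + " and PTX " +
           std::to_string(R.MinPTX / 10) + "." + std::to_string(R.MinPTX % 10);
  return std::nullopt;
}
```

With this in place, calling `checkBuiltin("__nvvm_fma_rn_bf16", {60, 70})` yields an error message for sm_60 rather than letting the call reach instruction selection.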


================
Comment at: llvm/include/llvm/IR/IntrinsicsNVVM.td:878
     def int_nvvm_fma # variant : ClangBuiltin<!strconcat("__nvvm_fma", variant)>,
-      DefaultAttrsIntrinsic<[llvm_i16_ty],
-        [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty],
+      DefaultAttrsIntrinsic<[llvm_bfloat_ty],
+        [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty],
----------------
This changes the signatures of existing intrinsics and builtins. While the change is correct, we should at least check that the MLIR tests are still passing.



================
Comment at: llvm/include/llvm/IR/IntrinsicsNVVM.td:1244-1251
   def int_nvvm_ff2bf16x2_rn : ClangBuiltin<"__nvvm_ff2bf16x2_rn">,
        Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
   def int_nvvm_ff2bf16x2_rn_relu : ClangBuiltin<"__nvvm_ff2bf16x2_rn_relu">,
       Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
   def int_nvvm_ff2bf16x2_rz : ClangBuiltin<"__nvvm_ff2bf16x2_rz">,
       Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
   def int_nvvm_ff2bf16x2_rz_relu : ClangBuiltin<"__nvvm_ff2bf16x2_rz_relu">,
----------------
We've removed the patterns matching these intrinsics in lib/Target/NVPTX/NVPTXIntrinsics.td so there's nothing to lower them to an instruction now. Was that intentional?
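Whatever the patterns end up lowering to, the intrinsics above return both converted values packed into one i32. As a reference for what the lowered instruction should compute, here is a sketch of `cvt.{rn,rz}{.relu}.bf16x2.f32 d, a, b` based on my reading of the PTX ISA description of `cvt`: the value converted from `a` goes into the upper half of `d`, the one from `b` into the lower half. The function names are hypothetical, NaN handling is deliberately not modeled, and `.relu` is modeled simply as clamping negative inputs to +0.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

enum class Rounding { RN, RZ };

// Converts one f32 value to bf16 bits with the given rounding mode.
uint16_t cvtF32ToBF16(float F, Rounding Mode, bool Relu) {
  assert(!std::isnan(F) && "NaN handling is not modeled in this sketch");
  if (Relu && F < 0.0f)
    F = 0.0f; // .relu clamps negative results to zero
  uint32_t U;
  std::memcpy(&U, &F, sizeof U);
  if (Mode == Rounding::RN)
    U += 0x7FFFu + ((U >> 16) & 1u); // round to nearest, ties to even
  // For Rounding::RZ, dropping the low 16 bits rounds toward zero.
  return static_cast<uint16_t>(U >> 16);
}

// Models cvt.{rn,rz}{.relu}.bf16x2.f32 d, a, b: A lands in the upper half
// of the result, B in the lower half.
uint32_t ff2bf16x2(float A, float B, Rounding Mode, bool Relu) {
  uint32_t Hi = cvtF32ToBF16(A, Mode, Relu);
  uint32_t Lo = cvtF32ToBF16(B, Mode, Relu);
  return (Hi << 16) | Lo;
}
```

For example, `ff2bf16x2(1.0f, 2.0f, Rounding::RN, false)` packs `0x3F80` (from `a`) above `0x4000` (from `b`), giving `0x3F804000`; a value such as `1.005859375f` distinguishes `rn` (`0x3F81`) from `rz` (`0x3F80`), which could be useful in tests.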


================
Comment at: llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp:64-69
+  case 9:
     OS << "%h";
     break;
   case 8:
+  case 10:
     OS << "%hh";
----------------
Looks like I forgot to remove those cases in my regclass patch. Will fix it shortly.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:640
                          ISD::FROUNDEVEN, ISD::FTRUNC}) {
+    setOperationAction(Op, MVT::bf16, Legal);
     setOperationAction(Op, MVT::f16, Legal);
----------------
Nit: sometimes bf16 variants are added above the fp16 variants, sometimes below. It would be nice to do it consistently. I guess we should just do a cleanup patch sorting these blocks in type order.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:2514-2516
+  if ((Isv2f16Orv2bf16Type(VT.getSimpleVT())) &&
       !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
                                       VT, *Store->getMemOperand()))
----------------
Unnecessary `()` around `Isv2f16Orv2bf16Type(VT.getSimpleVT())`.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:2601
       // store them with st.v4.b32.
-      assert((EltVT == MVT::f16 || EltVT == MVT::bf16) &&
+      assert((Isf16Orbf16Type(EltVT.getSimpleVT())) &&
              "Wrong type for the vector.");
----------------
Ditto.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td:1316-1326
 defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Int16Regs, doF32FTZ>;
 defm FMA16     : FMA_F16<"fma.rn.f16", f16, Int16Regs, True>;
 defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>;
 defm FMA16x2     : FMA_F16<"fma.rn.f16x2", v2f16, Int32Regs, True>;
+defm BFMA16_ftz : FMA_BF16<"fma.rn.ftz.bf16", bf16, Int16Regs, doF32FTZ>;
+defm BFMA16     : FMA_BF16<"fma.rn.bf16", bf16, Int16Regs, True>;
+defm BFMA16x2_ftz : FMA_BF16<"fma.rn.ftz.bf16x2", v2bf16, Int32Regs, doF32FTZ>;
----------------
Nit: align ':' across the block.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td:1892-1893
                            "mov.b16 \t$dst, $src;", []>;
+  def BFMOV16rr : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src),
+                           "mov.b16 \t$dst, $src;", []>;
   def FMOV32rr : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src),
----------------
We now have `IMOVB16rr` for this.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td:2811-2814
   defm LDV_f16 : LD_VEC<Int16Regs>;
   defm LDV_f16x2 : LD_VEC<Int32Regs>;
+  defm LDV_bf16 : LD_VEC<Int16Regs>;
+  defm LDV_bf16x2 : LD_VEC<Int32Regs>;
----------------
Those should no longer be necessary. LDV_i16/LDV_i32 should do the job.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td:2908-2911
   defm STV_f16 : ST_VEC<Int16Regs>;
   defm STV_f16x2 : ST_VEC<Int32Regs>;
+  defm STV_bf16 : ST_VEC<Int16Regs>;
+  defm STV_bf16x2 : ST_VEC<Int32Regs>;
----------------
Ditto.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td:659-666
 multiclass CVT_FROM_FLOAT_SM80<string FromName, RegisterClass RC> {
     def _f32 :
       NVPTXInst<(outs RC:$dst),
                 (ins Float32Regs:$src, CvtMode:$mode),
                 !strconcat("cvt${mode:base}${mode:relu}.",
                 FromName, ".f32 \t$dst, $src;"), []>,
                 Requires<[hasPTX70, hasSM80]>;
----------------
tra wrote:
> I think this multiclass can be deleted now.
Unused `CVT_FROM_FLOAT_SM80` should be removed, still.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td:2094
 
-  // f32 -> pred
+  //bf16 -> pred
+  def : Pat<(i1 (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b))),
----------------
tra wrote:
> Nit: needs a space between the "//" and the comment itself. Same in the f32 block below.
Not done yet.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXIntrinsics.td:1271-1287
-def : Pat<(int_nvvm_ff2f16x2_rn Float32Regs:$a, Float32Regs:$b),
-          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
-def : Pat<(int_nvvm_ff2f16x2_rn_relu Float32Regs:$a, Float32Regs:$b),
-          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
-def : Pat<(int_nvvm_ff2f16x2_rz Float32Regs:$a, Float32Regs:$b),
-          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ)>;
-def : Pat<(int_nvvm_ff2f16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
----------------
Were these patterns removed intentionally? We still have intrinsics/builtins defined in llvm/include/llvm/IR/IntrinsicsNVVM.td and still need to lower them.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXIntrinsics.td:2199-2206
 defm INT_PTX_LDU_G_v2f16_ELE
   : VLDU_G_ELE_V2<"v2.b16 \t{{$dst1, $dst2}}, [$src];", Int16Regs>;
 defm INT_PTX_LDU_G_v2f16x2_ELE
   : VLDU_G_ELE_V2<"v2.b32 \t{{$dst1, $dst2}}, [$src];", Int32Regs>;
+defm INT_PTX_LDU_G_v2bf16_ELE
+  : VLDU_G_ELE_V2<"v2.b16 \t{{$dst1, $dst2}}, [$src];", Int16Regs>;
+defm INT_PTX_LDU_G_v2bf16x2_ELE
----------------
Another place I've missed. I think these are also no longer necessary.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp:29-32
+static cl::opt<bool>
+    NoBF16Math("nvptx-no-bf16-math", cl::Hidden,
+               cl::desc("NVPTX Specific: Disable generation of bf16 math ops."),
+               cl::init(false));
----------------
I wonder if we need it at all for bf16.


The knob was needed for fp16 because some of the Pascal-generation GPUs had a very limited number of fp16 units (1:64 fp16:fp32 ratio on sm_61 vs 2:1 on sm_60, IIRC). While they could technically execute fp16 operations, it was much faster to promote them to fp32.

I do not know how well bf16 performs on consumer-grade GPU variants. Only if bf16 throughput there is similarly limited, as fp16 is on sm_61, will we still need this knob.
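For reference, this is roughly what "promote to fp32" amounts to numerically for a bf16 add: widen both operands (exact, since bf16 is the upper half of an f32), add in f32, and round back to bf16. A minimal sketch assuming round-to-nearest-even on the way back; the function names are hypothetical and NaN handling is not modeled. On hardware, each promoted op costs extra conversions in addition to the f32 arithmetic, which is the trade-off the knob controls.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Widening bf16 -> f32 is exact: the bf16 bits become the top 16 bits.
float bf16ToF32(uint16_t H) {
  uint32_t U = static_cast<uint32_t>(H) << 16;
  float F;
  std::memcpy(&F, &U, sizeof F);
  return F;
}

// Narrowing f32 -> bf16 with round-to-nearest-even.
uint16_t f32ToBF16RN(float F) {
  assert(!std::isnan(F) && "NaN handling is not modeled in this sketch");
  uint32_t U;
  std::memcpy(&U, &F, sizeof U);
  U += 0x7FFFu + ((U >> 16) & 1u);
  return static_cast<uint16_t>(U >> 16);
}

// bf16 fadd legalized by promotion: extend, add in f32, round back.
uint16_t faddBF16ViaF32(uint16_t A, uint16_t B) {
  return f32ToBF16RN(bf16ToF32(A) + bf16ToF32(B));
}
```

For example, adding `0x3F80` (1.0) and `0x3B80` (2^-8) produces an exact f32 tie, which rounds back to even and gives `0x3F80`.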


================
Comment at: llvm/test/CodeGen/NVPTX/bf16-instructions.ll:13
+define bfloat @test_fadd(bfloat %0, bfloat %1) {
+  %3 = fadd bfloat %0, %1
+  ret bfloat %3
----------------
We need more bf16 tests covering:
- bf16 constants.
- extracting high/low/both elements of v2bf16
- combining two bf16 scalars into v2bf16
- conversion bf16 <-> fp32
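To work out the expected values for these tests, a few host-side helpers can be handy: bf16 bit patterns for constants, widening bf16 to fp32, and packing two bf16 scalars into a v2bf16 or extracting one back out. This is a sketch with hypothetical helper names. It assumes v2bf16 element 0 occupies the low 16 bits of the 32-bit register, as with v2f16. That lane order is worth double-checking against the actual PTX output, especially since `cvt.*.bf16x2.f32` places its first operand in the upper half.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// bf16 bit pattern of an f32 constant, rounded to nearest-even.
uint16_t bf16Bits(float F) {
  assert(!std::isnan(F) && "NaN handling is not modeled in this sketch");
  uint32_t U;
  std::memcpy(&U, &F, sizeof U);
  U += 0x7FFFu + ((U >> 16) & 1u);
  return static_cast<uint16_t>(U >> 16);
}

// Exact bf16 -> f32 conversion.
float bf16ToF32(uint16_t H) {
  uint32_t U = static_cast<uint32_t>(H) << 16;
  float F;
  std::memcpy(&F, &U, sizeof F);
  return F;
}

// Packs two bf16 scalars into a v2bf16 (assumed: element 0 in the low bits).
uint32_t packV2BF16(uint16_t Elt0, uint16_t Elt1) {
  return (static_cast<uint32_t>(Elt1) << 16) | Elt0;
}

// Extracts element Idx (0 = low half, 1 = high half) from a packed v2bf16.
uint16_t extractElt(uint32_t V, unsigned Idx) {
  assert(Idx < 2 && "v2bf16 has only two elements");
  return static_cast<uint16_t>(V >> (16 * Idx));
}
```

For instance, `bf16Bits(1.0f)` is `0x3F80`, `bf16Bits(3.140625f)` is `0x4049` (exactly representable), and `packV2BF16(0x3F80, 0x4000)` is `0x40003F80` under the assumed lane order.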


Repository:
  rG LLVM Github Monorepo


https://reviews.llvm.org/D144911


