[PATCH] D144911: adding bf16 support to NVPTX

Artem Belevich via Phabricator via cfe-commits cfe-commits at lists.llvm.org
Mon Jun 12 16:35:53 PDT 2023


tra added a comment.

Almost there. Just a few cosmetic nits remaining.



================
Comment at: llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp:64-69
+  case 9:
     OS << "%h";
     break;
   case 8:
+  case 10:
     OS << "%hh";
----------------
tra wrote:
> Looks like I forgot to remove those cases in my regclass patch. Will fix it shortly.
Still not fixed.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp:632-634
+  for (const auto &VT : {MVT::bf16, MVT::v2bf16})
+    setOperationAction(ISD::FNEG, VT,
+                       IsBFP16FP16x2NegAvailable ? Legal : Expand);
----------------
This could simply be:
```
setBF16OperationAction(ISD::FNEG, MVT::bf16, Legal, Expand);
setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
```
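A minimal, self-contained C++ model of what the suggested helper is assumed to do: pick the first action when the subtarget reports bf16 math support, and the fallback action otherwise. `Lowering`, `Subtarget`, `configureFNEG` and the small enums below are hypothetical stand-ins, not LLVM's real `TargetLowering`, `MVT` or `ISD` types. The real availability check may also involve PTX/SM versions (the FNEG pattern below requires `hasPTX<70>` and `hasSM<80>`).

```cpp
#include <cassert>
#include <map>
#include <utility>

// Hypothetical stand-ins for LLVM's LegalizeAction, MVT and ISD opcodes;
// only the values needed for this example are modeled.
enum class LegalizeAction { Legal, Expand };
enum class MVT { bf16, v2bf16 };
enum class ISD { FNEG };

// Hypothetical subtarget: only the bf16 capability query is modeled.
struct Subtarget {
  bool BF16Math = false;
  bool hasBF16Math() const { return BF16Math; }
};

// Hypothetical lowering object that records the action chosen for each
// (opcode, type) pair, playing the role of setOperationAction.
class Lowering {
public:
  explicit Lowering(const Subtarget &STI) : STI(STI) {}

  void setOperationAction(ISD Op, MVT VT, LegalizeAction Action) {
    Actions[{Op, VT}] = Action;
  }

  // Assumed helper semantics: use Action if bf16 math is available,
  // otherwise fall back to NoBF16Action.
  void setBF16OperationAction(ISD Op, MVT VT, LegalizeAction Action,
                              LegalizeAction NoBF16Action) {
    setOperationAction(Op, VT, STI.hasBF16Math() ? Action : NoBF16Action);
  }

  LegalizeAction getOperationAction(ISD Op, MVT VT) const {
    return Actions.at({Op, VT});
  }

  // The two calls suggested in the review comment above.
  void configureFNEG() {
    setBF16OperationAction(ISD::FNEG, MVT::bf16, LegalizeAction::Legal,
                           LegalizeAction::Expand);
    setBF16OperationAction(ISD::FNEG, MVT::v2bf16, LegalizeAction::Legal,
                           LegalizeAction::Expand);
  }

private:
  const Subtarget &STI;
  std::map<std::pair<ISD, MVT>, LegalizeAction> Actions;
};
```

Keeping the capability check inside one helper means each opcode/type pair needs a single line, instead of repeating the `? Legal : Expand` ternary at every call site.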



================
Comment at: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td:159
 def useFP16Math: Predicate<"Subtarget->allowFP16Math()">;
+def useBFP16Math: Predicate<"Subtarget->allowBF16Math()">;
 
----------------
Nit: this should be `useBF16Math`, following the fp16 -> bf16 naming.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td:1118
+                [(set RC:$dst, (fneg (T RC:$src)))]>,
+                Requires<[useFP16Math, hasPTX<70>, hasSM<80>, Pred]>;
+def BFNEG16_ftz   : FNEG_BF16_F16X2<"neg.ftz.bf16", bf16, Int16Regs, doF32FTZ>;
----------------
I think you need to use `useBF16Math` here.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp:68
+
+bool NVPTXSubtarget::allowBF16Math() const { return hasBF16Math(); }
----------------
We do not need `allowBF16Math` anymore. Just use `hasBF16Math()`.


================
Comment at: llvm/lib/Target/NVPTX/NVPTXSubtarget.h:81
   bool allowFP16Math() const;
+  bool allowBF16Math() const;
   bool hasMaskOperator() const { return PTXVersion >= 71; }
----------------
Not needed.


================
Comment at: llvm/test/CodeGen/NVPTX/bf16-instructions.ll:16
+define bfloat @test_fadd(bfloat %0, bfloat %1) {
+  %3 = fadd bfloat %0, %1
+  ret bfloat %3
----------------
Another useful test would be `fadd bfloat %0, 1.0`, which adds a constant operand.



Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D144911/new/

https://reviews.llvm.org/D144911
