[llvm] [NVPTX] Lower bfloat16 add/mul/sub as fma on SM80 (PR #121065)

via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 26 13:43:54 PST 2024


================
@@ -853,6 +853,16 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
       AddPromotedToType(Op, MVT::bf16, MVT::f32);
   }
 
+  // Lower bf16 add/mul/sub as fma when it avoids promotion
+  for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
+    for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
+      if (getOperationAction(Op, VT) != Legal &&
+          getOperationAction(ISD::FMA, VT) == Legal) {
----------------
peterbell10 wrote:

> This is cumbersome, we usually don't write legalizer rules in terms of other legalizer rules.

I could easily write it in terms of SM and PTX version ranges instead; I just felt this expressed the rationale behind the version range more directly.

> I think you'd be best off just putting this logic into the default Expand action for add/fmul/fsub. If the FMA is legal, you emit the appropriate sequence before falling back to the default libcall expansion. Then you shouldn't need to touch the target rules here

I'm not sure this makes sense. The FTZ logic is target-specific, and we also want to fall back to promotion here, not to a libcall.
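For readers following along, here is a minimal sketch of the arithmetic identities that lowering add/mul/sub as fma relies on. It uses plain `float` and `std::fma` rather than bf16, the helper names are hypothetical, and it is not the code in this PR; in particular, the choice of `-0.0` as the addend for the multiply is an assumption made here so that the sign of a zero product is preserved.

```cpp
#include <cmath>

// Hypothetical helpers: each expresses one binary op as a single fused
// multiply-add. Plain float stands in for bf16 for illustration only.

// a + b == fma(a, 1.0, b): multiplying by 1.0 is exact.
inline float add_via_fma(float a, float b) { return std::fma(a, 1.0f, b); }

// a - b == fma(b, -1.0, a): multiplying by -1.0 only flips the sign.
inline float sub_via_fma(float a, float b) { return std::fma(b, -1.0f, a); }

// a * b == fma(a, b, -0.0): adding -0.0 keeps a -0.0 product negative,
// whereas adding +0.0 would turn it into +0.0 under round-to-nearest.
inline float mul_via_fma(float a, float b) { return std::fma(a, b, -0.0f); }
```

Because each identity performs a single rounding, the result matches what a native add, sub, or mul would produce in the same format, which is why the fma form can stand in when only the fma instruction is legal.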

https://github.com/llvm/llvm-project/pull/121065

