[llvm] [NVPTX] Select bfloat16 add/mul/sub as fma on SM80 (PR #121065)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 9 14:41:04 PST 2025


================
@@ -2450,6 +2457,66 @@ bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) {
   return true;
 }
 
+// Select bf16/bf16v2 FADD, FSUB, FMUL as fma on targets with only fma
+bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) {
+  EVT VT = SDValue(N, 0).getValueType();
+  if (VT.getScalarType() != MVT::bf16)
+    return false;
+
+  const NVPTXSubtarget *STI = TM.getSubtargetImpl();
+  const bool IsNativelySupported =
+      STI->getSmVersion() >= 90 && STI->getPTXVersion() >= 78;
+  if (IsNativelySupported)
+    return false;
+
+  assert(VT == MVT::bf16 || VT == MVT::v2bf16);
----------------
peterbell10 wrote:

I'm asserting it's either a scalar or vector of length 2.
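
For readers without the full patch: the hunk quoted above does not show how the operands are rewritten, so the following is only a sketch of the usual identities for expressing add, sub, and mul through a single fused multiply-add. It assumes `float` as a stand-in for bf16, and the helper names (`addViaFma`, `subViaFma`, `mulViaFma`, `checkFmaIdentities`) are hypothetical and do not come from the PR.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helpers, not taken from the PR. float stands in for bf16.

// a + b == fma(a, 1.0, b): multiplying by 1.0 is exact.
inline float addViaFma(float a, float b) { return std::fma(a, 1.0f, b); }

// a - b == fma(b, -1.0, a): multiplying by -1.0 only flips the sign.
inline float subViaFma(float a, float b) { return std::fma(b, -1.0f, a); }

// a * b == fma(a, b, -0.0): the addend is -0.0 rather than +0.0 so that
// a product of -0.0 keeps its sign (-0.0 + +0.0 would round to +0.0).
inline float mulViaFma(float a, float b) { return std::fma(a, b, -0.0f); }

// Spot-check the identities, including the signed-zero case for mul.
inline void checkFmaIdentities() {
  assert(addViaFma(1.5f, 2.25f) == 1.5f + 2.25f);
  assert(subViaFma(1.5f, 2.25f) == 1.5f - 2.25f);
  assert(mulViaFma(1.5f, 2.0f) == 1.5f * 2.0f);
  assert(std::signbit(mulViaFma(-1.0f, 0.0f)));  // -0.0 is preserved
  assert(!std::signbit(mulViaFma(1.0f, 0.0f)));  // +0.0 stays +0.0
}
```

Calling `checkFmaIdentities()` from a test driver exercises all three rewrites. The same identities apply lane-wise to a two-element vector, which is the other type the assert admits.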

https://github.com/llvm/llvm-project/pull/121065

