[llvm] [NVPTX] Select bfloat16 add/mul/sub as fma on SM80 (PR #121065)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 9 14:28:23 PST 2025


================
@@ -2450,6 +2457,66 @@ bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) {
   return true;
 }
 
+// Select bf16/bf16v2 FADD, FSUB, FMUL as fma on targets with only fma
+bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) {
+  EVT VT = SDValue(N, 0).getValueType();
+  if (VT.getScalarType() != MVT::bf16)
+    return false;
+
+  const NVPTXSubtarget *STI = TM.getSubtargetImpl();
+  const bool IsNativelySupported =
+      STI->getSmVersion() >= 90 && STI->getPTXVersion() >= 78;
+  if (IsNativelySupported)
+    return false;
+
+  assert(VT == MVT::bf16 || VT == MVT::v2bf16);
+  const bool IsVec = VT == MVT::v2bf16;
+  SDLoc DL(N);
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  SmallVector<SDValue, 3> Operands;
+  auto GetConstant = [&](float Value) -> SDValue {
+    APFloat APF(Value);
+    bool LosesInfo;
+    APF.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &LosesInfo);
+    assert(!LosesInfo);
+    if (IsVec) {
+      auto API = APF.bitcastToAPInt();
+      API = API.concat(API);
+      auto Const = CurDAG->getTargetConstant(API, DL, MVT::i32);
+      return SDValue(CurDAG->getMachineNode(NVPTX::IMOV32ri, DL, VT, Const), 0);
----------------
peterbell10 wrote:

No, there doesn't seem to be any support for bf16 immediate values; ptxas complains:
```
ptxas /tmp/tmplmhxk1av.ptx, line 79; error   : Arguments mismatch for instruction 'fma'
ptxas fatal   : Ptx assembly aborted due to errors
```

https://github.com/llvm/llvm-project/pull/121065


More information about the llvm-commits mailing list