[llvm] [NVPTX] Select bfloat16 add/mul/sub as fma on SM80 (PR #121065)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 9 14:28:23 PST 2025
================
@@ -2450,6 +2457,66 @@ bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) {
return true;
}
+// Select bf16/bf16v2 FADD, FSUB, FMUL as fma on targets with only fma
+bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) {
+ EVT VT = SDValue(N, 0).getValueType();
+ if (VT.getScalarType() != MVT::bf16)
+ return false;
+
+ const NVPTXSubtarget *STI = TM.getSubtargetImpl();
+ const bool IsNativelySupported =
+ STI->getSmVersion() >= 90 && STI->getPTXVersion() >= 78;
+ if (IsNativelySupported)
+ return false;
+
+ assert(VT == MVT::bf16 || VT == MVT::v2bf16);
+ const bool IsVec = VT == MVT::v2bf16;
+ SDLoc DL(N);
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ SmallVector<SDValue, 3> Operands;
+ auto GetConstant = [&](float Value) -> SDValue {
+ APFloat APF(Value);
+ bool LosesInfo;
+ APF.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &LosesInfo);
+ assert(!LosesInfo);
+ if (IsVec) {
+ auto API = APF.bitcastToAPInt();
+ API = API.concat(API);
+ auto Const = CurDAG->getTargetConstant(API, DL, MVT::i32);
+ return SDValue(CurDAG->getMachineNode(NVPTX::IMOV32ri, DL, VT, Const), 0);
----------------
peterbell10 wrote:
No, there doesn't seem to be any support for bf16 immediate values; ptxas complains:
```
ptxas /tmp/tmplmhxk1av.ptx, line 79; error : Arguments mismatch for instruction 'fma'
ptxas fatal : Ptx assembly aborted due to errors
```
https://github.com/llvm/llvm-project/pull/121065
More information about the llvm-commits
mailing list