[llvm] [NVPTX] Select bfloat16 add/mul/sub as fma on SM80 (PR #121065)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 9 14:47:31 PST 2025
================
@@ -2450,6 +2457,66 @@ bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) {
return true;
}
+// Select bf16/bf16v2 FADD, FSUB, FMUL as fma on targets with only fma
+bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) {
+ EVT VT = SDValue(N, 0).getValueType();
+ if (VT.getScalarType() != MVT::bf16)
+ return false;
+
+ const NVPTXSubtarget *STI = TM.getSubtargetImpl();
+ const bool IsNativelySupported =
+ STI->getSmVersion() >= 90 && STI->getPTXVersion() >= 78;
+ if (IsNativelySupported)
+ return false;
+
+ assert(VT == MVT::bf16 || VT == MVT::v2bf16);
+ const bool IsVec = VT == MVT::v2bf16;
+ SDLoc DL(N);
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ SmallVector<SDValue, 3> Operands;
+ auto GetConstant = [&](float Value) -> SDValue {
+ APFloat APF(Value);
+ bool LosesInfo;
+ APF.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &LosesInfo);
+ assert(!LosesInfo);
+ if (IsVec) {
+ auto API = APF.bitcastToAPInt();
+ API = API.concat(API);
+ auto Const = CurDAG->getTargetConstant(API, DL, MVT::i32);
+ return SDValue(CurDAG->getMachineNode(NVPTX::IMOV32ri, DL, VT, Const), 0);
+ }
+ auto Const = CurDAG->getTargetConstantFP(APF, DL, VT);
+ return SDValue(CurDAG->getMachineNode(NVPTX::BFMOV16ri, DL, VT, Const), 0);
+ };
+
+ switch (N->getOpcode()) {
+ case ISD::FADD: {
----------------
peterbell10 wrote:
Removed
https://github.com/llvm/llvm-project/pull/121065
More information about the llvm-commits
mailing list