[llvm] [NVPTX] Lower bfloat16 add/mul/sub as fma on SM80 (PR #121065)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Dec 25 18:21:39 PST 2024
================
@@ -2490,6 +2500,62 @@ SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
}
+/// Evaluate the floating-point binary node \p N in f32 precision.
+///
+/// Both operands are converted to f32 (or, when N's result type is a vector,
+/// to a vector of f32 with the same element count). The same opcode is then
+/// rebuilt at that wider type, carrying over N's node flags. The result is
+/// converted back to N's original value type.
+static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
+  const EVT OrigVT = N->getValueType(0);
+  const SDLoc DL(N);
+  // Keep the vector shape of the original type; only the element type changes.
+  const EVT WideVT =
+      OrigVT.isVector()
+          ? EVT::getVectorVT(*DAG.getContext(), MVT::f32,
+                             OrigVT.getVectorElementCount())
+          : EVT(MVT::f32);
+  SDValue LHS = DAG.getFPExtendOrRound(N->getOperand(0), DL, WideVT);
+  SDValue RHS = DAG.getFPExtendOrRound(N->getOperand(1), DL, WideVT);
+  SDValue WideResult =
+      DAG.getNode(N->getOpcode(), DL, WideVT, LHS, RHS, N->getFlags());
+  // Convert back to the caller-visible type.
+  return DAG.getFPExtendOrRound(WideResult, DL, OrigVT);
+}
+
+SDValue NVPTXTargetLowering::LowerFADD(SDValue Op, SelectionDAG &DAG) const {
+ // No fma.ftz for bf16, so fall back to promotion
+ if (useF32FTZ(DAG.getMachineFunction())) {
+ return PromoteBinOpToF32(Op.getNode(), DAG);
+ }
----------------
peterbell10 wrote:
I don't think we have access to the machine function inside the constructor, so we can't tell if FTZ is enabled or not.
https://github.com/llvm/llvm-project/pull/121065
More information about the llvm-commits mailing list