[llvm] [NVPTX] Lower bfloat16 add/mul/sub as fma on SM80 (PR #121065)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Wed Dec 25 22:12:20 PST 2024
================
@@ -2490,6 +2500,62 @@ SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
}
+/// Rebuild the two-operand node \p N on f32 values.
+///
+/// Both operands are converted with getFPExtendOrRound to f32 (or, when the
+/// result type of \p N is a vector, to a vector of f32 with the same element
+/// count). The opcode of \p N is then applied to the converted operands with
+/// the flags of \p N, and the result is converted back to the original value
+/// type of \p N.
+///
+/// \returns the value converted back to the result type of \p N.
+static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
+  const EVT ResultVT = N->getValueType(0);
+  const EVT ScalarF32 = MVT::f32;
+  // Keep the element count of a vector result; a scalar result uses plain f32.
+  const EVT WideVT =
+      ResultVT.isVector()
+          ? EVT::getVectorVT(*DAG.getContext(), ScalarF32,
+                             ResultVT.getVectorElementCount())
+          : ScalarF32;
+
+  SDLoc Loc(N);
+  SDValue WideLHS = DAG.getFPExtendOrRound(N->getOperand(0), Loc, WideVT);
+  SDValue WideRHS = DAG.getFPExtendOrRound(N->getOperand(1), Loc, WideVT);
+  SDValue WideResult = DAG.getNode(N->getOpcode(), Loc, WideVT, WideLHS,
+                                   WideRHS, N->getFlags());
+  return DAG.getFPExtendOrRound(WideResult, Loc, ResultVT);
+}
+
+/// Lower an FADD node \p Op.
+///
+/// When useF32FTZ reports true for the current machine function, the add is
+/// promoted through PromoteBinOpToF32, because there is no fma.ftz for bf16.
+/// Otherwise the add is rewritten as an FMA with a multiplier of 1.0.
+///
+/// \returns the replacement value for \p Op.
+SDValue NVPTXTargetLowering::LowerFADD(SDValue Op, SelectionDAG &DAG) const {
+  // There is no fma.ftz for bf16, so promotion is the fallback in FTZ mode.
+  if (useF32FTZ(DAG.getMachineFunction()))
+    return PromoteBinOpToF32(Op.getNode(), DAG);
+
+  // a + b == a * 1.0 + b, i.e. FADD(a, b) becomes FMA(a, 1.0, b).
+  SDLoc Loc(Op);
+  const EVT ResultVT = Op.getValueType();
+  SDValue Multiplier = DAG.getConstantFP(1.0, Loc, ResultVT);
+  SDValue FMAOps[] = {Op->getOperand(0), Multiplier, Op->getOperand(1)};
+  return DAG.getNode(ISD::FMA, Loc, ResultVT, FMAOps);
+}
+
+/// Lower an FSUB node \p Op.
+///
+/// When useF32FTZ reports true for the current machine function, the subtract
+/// is promoted through PromoteBinOpToF32, because there is no fma.ftz for
+/// bf16. Otherwise the subtract is rewritten as an FMA with a multiplier of
+/// -1.0 applied to the second operand.
+///
+/// \returns the replacement value for \p Op.
+SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
+  // There is no fma.ftz for bf16, so promotion is the fallback in FTZ mode.
+  if (useF32FTZ(DAG.getMachineFunction()))
+    return PromoteBinOpToF32(Op.getNode(), DAG);
+
+  // a - b == b * -1.0 + a, i.e. FSUB(a, b) becomes FMA(b, -1.0, a).
+  SDLoc Loc(Op);
+  const EVT ResultVT = Op.getValueType();
+  SDValue Multiplier = DAG.getConstantFP(-1.0, Loc, ResultVT);
+  SDValue FMAOps[] = {Op->getOperand(1), Multiplier, Op->getOperand(0)};
+  return DAG.getNode(ISD::FMA, Loc, ResultVT, FMAOps);
+}
+
+SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
+ // No fma.ftz for bf16, so fall back to promotion
+ if (useF32FTZ(DAG.getMachineFunction())) {
+ return PromoteBinOpToF32(Op.getNode(), DAG);
+ }
+
+ // FMUL(a, b) -> FMA(a, b, 0.0)
+ SDLoc DL(Op);
+ auto VT = Op.getValueType();
+ auto Zero = DAG.getConstantFP(0.0, DL, VT);
----------------
arsenm wrote:
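I think this needs to be -0.0: with a +0.0 addend, FMA(a, b, 0.0) returns +0.0 when a * b is -0.0, so the sign of a negative-zero product is lost.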
https://github.com/llvm/llvm-project/pull/121065
More information about the llvm-commits
mailing list