[llvm] [NVPTX] Lower bfloat16 add/mul/sub as fma on SM80 (PR #121065)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 25 22:12:20 PST 2024


================
@@ -2490,6 +2500,62 @@ SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
   return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
 }
 
+static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) { // Redo binary FP node N in f32, then convert the result back to N's type.
+  EVT VT = N->getValueType(0);
+  EVT NVT = MVT::f32;
+  if (VT.isVector()) { // Vector input: use an f32 vector with the same element count.
+    NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
+  }
+  SDLoc DL(N);
+  SDValue Tmp0 = DAG.getFPExtendOrRound(N->getOperand(0), DL, NVT); // Convert both operands to NVT.
+  SDValue Tmp1 = DAG.getFPExtendOrRound(N->getOperand(1), DL, NVT);
+  SDValue Res = DAG.getNode(N->getOpcode(), DL, NVT, Tmp0, Tmp1, N->getFlags()); // Same opcode and flags as N, but in NVT.
+  return DAG.getFPExtendOrRound(Res, DL, VT); // Convert the result back to the original type VT.
+}
+
+SDValue NVPTXTargetLowering::LowerFADD(SDValue Op, SelectionDAG &DAG) const {
+  // There is no fma.ftz for bf16, so when useF32FTZ() holds for this function
+  // the add is promoted to f32 via PromoteBinOpToF32 instead of using FMA.
+  if (useF32FTZ(DAG.getMachineFunction())) {
+    return PromoteBinOpToF32(Op.getNode(), DAG);
+  }
+
+  // Otherwise: FADD(a, b) -> FMA(a, 1.0, b), i.e. a * 1.0 + b in Op's type.
+  SDLoc DL(Op);
+  auto VT = Op.getValueType();
+  auto One = DAG.getConstantFP(1.0, DL, VT);
+  SmallVector<SDValue, 3> Operands{Op->getOperand(0), One, Op->getOperand(1)};
+  return DAG.getNode(ISD::FMA, DL, VT, Operands); // NOTE(review): Op's node flags are not passed explicitly here — confirm intended.
+}
+
+SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
+  // There is no fma.ftz for bf16, so when useF32FTZ() holds for this function
+  // the subtract is promoted to f32 via PromoteBinOpToF32 instead of using FMA.
+  if (useF32FTZ(DAG.getMachineFunction())) {
+    return PromoteBinOpToF32(Op.getNode(), DAG);
+  }
+
+  // Otherwise: FSUB(a, b) -> FMA(b, -1.0, a), i.e. b * -1.0 + a in Op's type.
+  SDLoc DL(Op);
+  auto VT = Op.getValueType();
+  auto NegOne = DAG.getConstantFP(-1.0, DL, VT);
+  SmallVector<SDValue, 3> Operands{Op->getOperand(1), NegOne, // Operand order is (b, -1.0, a): the subtrahend is the one negated.
+                                   Op->getOperand(0)};
+  return DAG.getNode(ISD::FMA, DL, VT, Operands); // NOTE(review): Op's node flags are not passed explicitly here — confirm intended.
+}
+
+SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
+  // No fma.ftz for bf16, so fall back to promotion
+  if (useF32FTZ(DAG.getMachineFunction())) {
+    return PromoteBinOpToF32(Op.getNode(), DAG);
+  }
+
+  // FMUL(a, b) -> FMA(a, b, 0.0)
+  SDLoc DL(Op);
+  auto VT = Op.getValueType();
+  auto Zero = DAG.getConstantFP(0.0, DL, VT);
----------------
arsenm wrote:

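I think this needs to be -0.0: with a +0.0 addend, a product of -0.0 comes out of the FMA as +0.0 instead of the -0.0 that FMUL would produce, whereas adding -0.0 preserves the product's sign.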

https://github.com/llvm/llvm-project/pull/121065


More information about the llvm-commits mailing list