[llvm] [NVPTX] Lower bfloat16 add/mul/sub as fma on SM80 (PR #121065)

Justin Fargnoli via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 26 12:49:52 PST 2024


================
@@ -2490,6 +2500,62 @@ SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
   return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
 }
 
+static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
+  EVT VT = N->getValueType(0);
+  EVT NVT = MVT::f32;
+  if (VT.isVector()) {
+    NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
+  }
+  SDLoc DL(N);
+  SDValue Tmp0 = DAG.getFPExtendOrRound(N->getOperand(0), DL, NVT);
+  SDValue Tmp1 = DAG.getFPExtendOrRound(N->getOperand(1), DL, NVT);
+  SDValue Res = DAG.getNode(N->getOpcode(), DL, NVT, Tmp0, Tmp1, N->getFlags());
+  return DAG.getFPExtendOrRound(Res, DL, VT);
+}
+
+SDValue NVPTXTargetLowering::LowerFADD(SDValue Op, SelectionDAG &DAG) const {
+  // No fma.ftz for bf16, so fall back to promotion
+  if (useF32FTZ(DAG.getMachineFunction())) {
+    return PromoteBinOpToF32(Op.getNode(), DAG);
+  }
----------------
justinfargnoli wrote:

Additionally, it looks like `TargetLowering` is constructed on a per-module basis, whereas `useF32FTZ` is queried per `MachineFunction`, so it would not make sense to perform this check in the constructor.

https://github.com/llvm/llvm-project/pull/121065


More information about the llvm-commits mailing list