[llvm] [NVPTX] Custom lower integer<->bf16 conversions for sm_80 (PR #74827)

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 11 10:56:27 PST 2023


================
@@ -2580,6 +2586,37 @@ SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
   return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
 }
 
+SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  // sm_90 has instructions for bf16 conversions, sm_80 only has f32 -> bf16.
+  if (Op.getValueType() == MVT::bf16 &&
+      (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
+    SDLoc Loc(Op);
+    return DAG.getNode(
+        ISD::FP_ROUND, Loc, MVT::bf16,
+        DAG.getNode(Op.getOpcode(), Loc, MVT::f32, Op.getOperand(0)),
+        DAG.getIntPtrConstant(0, Loc));
+  }
+
+  // Everything else is considered legal.
+  return Op;
+}
+
+SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  // sm_90 has instructions for bf16 conversions, sm_80 only has f32.
+  if (Op.getOperand(0).getValueType() == MVT::bf16 &&
+      (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
+    SDLoc Loc(Op);
+    return DAG.getNode(
+        Op.getOpcode(), Loc, Op.getValueType(),
+        DAG.getNode(ISD::FP_EXTEND, Loc, MVT::f32, Op.getOperand(0)));
----------------
Artem-B wrote:

Same here: `FP_TO_BF16` exists. It would be interesting to see what LLVM actually generates, and then check which variant produces better SASS on sm_80 and sm_90.
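For reference, the two lowerings in the hunk route every bf16 conversion through f32. `INT_TO_FP` becomes an f32 conversion followed by `FP_ROUND` to bf16, and `FP_TO_INT` becomes `FP_EXTEND` to f32 followed by the f32 conversion. The host-side C++ sketch below models only the numeric behavior of that two-step path, not the PTX or SASS LLVM emits. All function names are hypothetical, and it assumes the f32 -> bf16 step rounds to nearest-even (as `cvt.rn.bf16.f32` does) and that the float -> int step truncates toward zero, as `FP_TO_SINT` does.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical helper: round an f32 to bf16 bits with round-to-nearest-even.
static uint16_t f32_to_bf16_rne(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof bits);
  if ((bits & 0x7fffffffu) > 0x7f800000u)  // NaN: keep it a quiet NaN
    return static_cast<uint16_t>((bits >> 16) | 0x0040u);
  uint32_t lsb = (bits >> 16) & 1u;  // ties go to the even bf16 value
  bits += 0x7fffu + lsb;
  return static_cast<uint16_t>(bits >> 16);
}

// bf16 -> f32 is exact: a bf16 value is the top 16 bits of an f32.
static float bf16_to_f32(uint16_t h) {
  uint32_t bits = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof f);
  return f;
}

// Models INT_TO_FP(bf16) => FP_ROUND(INT_TO_FP(i32 -> f32), bf16).
// For |i| > 2^24 the int -> f32 step already rounds, so this is a
// double rounding and can differ from rounding once directly to bf16.
static uint16_t sint_to_bf16_via_f32(int32_t i) {
  return f32_to_bf16_rne(static_cast<float>(i));
}

// Models FP_TO_SINT(bf16) => FP_TO_SINT(FP_EXTEND(bf16 -> f32)).
static int32_t bf16_to_sint_via_f32(uint16_t h) {
  float f = bf16_to_f32(h);
  // C++ leaves out-of-range (and NaN) float -> int conversions undefined,
  // so this sketch only asserts instead of modeling hardware behavior.
  assert(f >= -2147483648.0f && f < 2147483648.0f);
  return static_cast<int32_t>(f);  // truncates toward zero
}
```

For example, `sint_to_bf16_via_f32(257)` ties between 256 and 258 and yields the bf16 for 256, and `bf16_to_sint_via_f32(0xC020)` (-2.5) yields -2. The double rounding noted in the comment is one behavioral difference worth checking alongside the SASS comparison: 16842753 (2^24 + 2^16 + 1) first rounds to 16842752 in f32, which then ties down to 16777216 in bf16, whereas a single round-to-nearest-even to bf16 would give 16908288.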

https://github.com/llvm/llvm-project/pull/74827

