[llvm] AMDGPU: Add round-to-odd rounding during f64 to bf16 conversion (PR #133995)

Changpeng Fang via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 1 22:28:41 PDT 2025


================
@@ -6888,23 +6887,33 @@ SDValue SITargetLowering::getFPExtOrFPRound(SelectionDAG &DAG, SDValue Op,
 }
 
 SDValue SITargetLowering::lowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
-  assert(Op.getValueType() == MVT::f16 &&
-         "Do not know how to custom lower FP_ROUND for non-f16 type");
-
   SDValue Src = Op.getOperand(0);
   EVT SrcVT = Src.getValueType();
-  if (SrcVT != MVT::f64)
-    return Op;
-
-  // TODO: Handle strictfp
-  if (Op.getOpcode() != ISD::FP_ROUND)
+  if (SrcVT.getScalarType() != MVT::f64)
     return Op;
 
+  EVT DstVT = Op.getValueType();
   SDLoc DL(Op);
+  if (DstVT == MVT::f16) {
+    // TODO: Handle strictfp
+    if (Op.getOpcode() != ISD::FP_ROUND)
+      return Op;
+
+    SDValue FpToFp16 = DAG.getNode(ISD::FP_TO_FP16, DL, MVT::i32, Src);
+    SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToFp16);
+    return DAG.getNode(ISD::BITCAST, DL, MVT::f16, Trunc);
+  }
+
+  assert (DstVT.getScalarType() == MVT::bf16 &&
+          "custom lower FP_ROUND for f16 or bf16");
+  assert (Subtarget->hasBF16ConversionInsts() && "f32 -> bf16 is legal");
 
-  SDValue FpToFp16 = DAG.getNode(ISD::FP_TO_FP16, DL, MVT::i32, Src);
-  SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToFp16);
-  return DAG.getNode(ISD::BITCAST, DL, MVT::f16, Trunc);
+  // Round-inexact-to-odd f64 to f32, then do the final rounding using the
+  // hardware f32 -> bf16 instruction.
+  EVT F32VT = SrcVT.isVector() ? SrcVT.changeVectorElementType(MVT::f32) :
+                                 MVT::f32;
+  SDValue Rod = expandRoundInexactToOdd(F32VT, Src, DL, DAG);
----------------
changpeng wrote:

No. If we return the original node, it will generate two conversions (f64 -> f32, then f32 -> bf16) and have the double-rounding issue.
Also, if we set the operation action to "Expand", it will do the round-inexact-to-odd step, but it won't generate the desired v_cvt_pk_bf16_f32 instruction for the f32 -> bf16 conversion.

https://github.com/llvm/llvm-project/pull/133995


More information about the llvm-commits mailing list