[llvm] Correctly round FP -> BF16 when SDAG expands such nodes (PR #82399)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 21 02:16:21 PST 2024


================
@@ -3219,8 +3219,98 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
   case ISD::FP_ROUND: {
     EVT VT = Node->getValueType(0);
     if (VT.getScalarType() == MVT::bf16) {
-      Results.push_back(
-          DAG.getNode(ISD::FP_TO_BF16, SDLoc(Node), VT, Node->getOperand(0)));
+      if (Node->getConstantOperandVal(1) == 1) {
+        Results.push_back(
+            DAG.getNode(ISD::FP_TO_BF16, SDLoc(Node), VT, Node->getOperand(0)));
+        break;
+      }
+      SDValue Op = Node->getOperand(0);
+      SDValue IsNaN = DAG.getSetCC(dl, getSetCCResultType(Op.getValueType()),
+                                   Op, Op, ISD::SETUO);
+      if (Op.getValueType() != MVT::f32) {
+        // We are rounding binary64/binary128 -> binary32 -> bfloat16. This
+        // can induce double-rounding which may alter the results. We can
+        // correct for this using a trick explained in: Boldo, Sylvie, and
+        // Guillaume Melquiond. "When double rounding is odd." 17th IMACS
+        // World Congress. 2005.
+        FloatSignAsInt ValueAsInt;
+        getSignAsIntValue(ValueAsInt, dl, Op);
+        EVT WideIntVT = ValueAsInt.IntValue.getValueType();
+        SDValue SignMask = DAG.getConstant(ValueAsInt.SignMask, dl, WideIntVT);
+        SDValue SignBit =
+            DAG.getNode(ISD::AND, dl, WideIntVT, ValueAsInt.IntValue, SignMask);
+        SDValue AbsWide;
+        if (TLI.isOperationLegalOrCustom(ISD::FABS, ValueAsInt.FloatVT)) {
+          AbsWide = DAG.getNode(ISD::FABS, dl, ValueAsInt.FloatVT, Op);
+        } else {
+          SDValue ClearSignMask =
+              DAG.getConstant(~ValueAsInt.SignMask, dl, WideIntVT);
+          SDValue ClearedSign = DAG.getNode(ISD::AND, dl, WideIntVT,
+                                            ValueAsInt.IntValue, ClearSignMask);
+          AbsWide = modifySignAsInt(ValueAsInt, dl, ClearedSign);
+        }
+        SDValue AbsNarrow =
+            DAG.getNode(ISD::FP_ROUND, dl, MVT::f32, AbsWide,
+                        DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
+        SDValue AbsNarrowAsWide =
+            DAG.getNode(ISD::FP_EXTEND, dl, ValueAsInt.FloatVT, AbsNarrow);
+
+        // We can keep the narrow value as-is if narrowing was exact (no
+        // rounding error), the wide value was NaN (the narrow value is also
+        // NaN and should be preserved) or if we rounded to the odd value.
+        SDValue NarrowBits = DAG.getNode(ISD::BITCAST, dl, MVT::i32, AbsNarrow);
+        SDValue One = DAG.getConstant(1, dl, MVT::i32);
+        SDValue NegativeOne = DAG.getConstant(-1, dl, MVT::i32);
+        SDValue And = DAG.getNode(ISD::AND, dl, MVT::i32, NarrowBits, One);
+        EVT I32CCVT = getSetCCResultType(And.getValueType());
+        SDValue Zero = DAG.getConstant(0, dl, MVT::i32);
+        SDValue AlreadyOdd = DAG.getSetCC(dl, I32CCVT, And, Zero, ISD::SETNE);
+
+        EVT WideSetCCVT = getSetCCResultType(AbsWide.getValueType());
+        SDValue KeepNarrow = DAG.getSetCC(dl, WideSetCCVT, AbsWide,
+                                          AbsNarrowAsWide, ISD::SETUEQ);
+        KeepNarrow =
+            DAG.getNode(ISD::OR, dl, WideSetCCVT, KeepNarrow, AlreadyOdd);
+        // We morally performed a round-down if `abs_narrow` is smaller than
+        // `abs_wide`.
+        SDValue NarrowIsRd = DAG.getSetCC(dl, WideSetCCVT, AbsWide,
+                                          AbsNarrowAsWide, ISD::SETOGT);
+        // If the narrow value is odd or exact, pick it.
+        // Otherwise, narrow is even and corresponds to either the rounded-up
+        // or rounded-down value. If narrow is the rounded-down value, we want
+        // the rounded-up value as it will be odd.
+        SDValue Adjust =
+            DAG.getSelect(dl, MVT::i32, NarrowIsRd, One, NegativeOne);
+        Adjust = DAG.getSelect(dl, MVT::i32, KeepNarrow, Zero, Adjust);
+        int ShiftAmount = ValueAsInt.SignBit - 31;
+        SDValue ShiftCnst = DAG.getConstant(
+            ShiftAmount, dl,
+            TLI.getShiftAmountTy(WideIntVT, DAG.getDataLayout()));
+        SignBit = DAG.getNode(ISD::SRL, dl, WideIntVT, SignBit, ShiftCnst);
+        SignBit = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, SignBit);
+        Op = DAG.getNode(ISD::OR, dl, MVT::i32, Adjust, SignBit);
+      } else {
+        Op = DAG.getNode(ISD::BITCAST, dl, MVT::i32, Op);
+      }
+
+      SDValue One = DAG.getConstant(1, dl, MVT::i32);
+      SDValue Lsb = DAG.getNode(
+          ISD::SRL, dl, MVT::i32, Op,
+          DAG.getConstant(16, dl,
+                          TLI.getShiftAmountTy(MVT::i32, DAG.getDataLayout())));
----------------
RKSimon wrote:

Use DAG.getShiftAmountConstant?

https://github.com/llvm/llvm-project/pull/82399


More information about the llvm-commits mailing list