[llvm] Correctly round FP -> BF16 when SDAG expands such nodes (PR #82399)
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 21 02:16:20 PST 2024
================
@@ -3219,8 +3219,98 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
case ISD::FP_ROUND: {
EVT VT = Node->getValueType(0);
if (VT.getScalarType() == MVT::bf16) {
- Results.push_back(
- DAG.getNode(ISD::FP_TO_BF16, SDLoc(Node), VT, Node->getOperand(0)));
+ if (Node->getConstantOperandVal(1) == 1) {
+ Results.push_back(
+ DAG.getNode(ISD::FP_TO_BF16, SDLoc(Node), VT, Node->getOperand(0)));
+ break;
+ }
+ SDValue Op = Node->getOperand(0);
+ SDValue IsNaN = DAG.getSetCC(dl, getSetCCResultType(Op.getValueType()),
+ Op, Op, ISD::SETUO);
+ if (Op.getValueType() != MVT::f32) {
+ // We are rounding binary64/binary128 -> binary32 -> bfloat16. This
+ // can induce double-rounding which may alter the results. We can
+ // correct for this using a trick explained in: Boldo, Sylvie, and
+ // Guillaume Melquiond. "When double rounding is odd." 17th IMACS
+ // World Congress. 2005.
+ FloatSignAsInt ValueAsInt;
+ getSignAsIntValue(ValueAsInt, dl, Op);
+ EVT WideIntVT = ValueAsInt.IntValue.getValueType();
+ SDValue SignMask = DAG.getConstant(ValueAsInt.SignMask, dl, WideIntVT);
+ SDValue SignBit =
+ DAG.getNode(ISD::AND, dl, WideIntVT, ValueAsInt.IntValue, SignMask);
+ SDValue AbsWide;
+ if (TLI.isOperationLegalOrCustom(ISD::FABS, ValueAsInt.FloatVT)) {
+ AbsWide = DAG.getNode(ISD::FABS, dl, ValueAsInt.FloatVT, Op);
+ } else {
+ SDValue ClearSignMask =
+ DAG.getConstant(~ValueAsInt.SignMask, dl, WideIntVT);
+ SDValue ClearedSign = DAG.getNode(ISD::AND, dl, WideIntVT,
+ ValueAsInt.IntValue, ClearSignMask);
+ AbsWide = modifySignAsInt(ValueAsInt, dl, ClearedSign);
+ }
+ SDValue AbsNarrow =
+ DAG.getNode(ISD::FP_ROUND, dl, MVT::f32, AbsWide,
+ DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
+ SDValue AbsNarrowAsWide =
+ DAG.getNode(ISD::FP_EXTEND, dl, ValueAsInt.FloatVT, AbsNarrow);
+
+ // We can keep the narrow value as-is if narrowing was exact (no
+ // rounding error), the wide value was NaN (the narrow value is also
+ // NaN and should be preserved) or if we rounded to the odd value.
+ SDValue NarrowBits = DAG.getNode(ISD::BITCAST, dl, MVT::i32, AbsNarrow);
+ SDValue One = DAG.getConstant(1, dl, MVT::i32);
+ SDValue NegativeOne = DAG.getConstant(-1, dl, MVT::i32);
+ SDValue And = DAG.getNode(ISD::AND, dl, MVT::i32, NarrowBits, One);
+ EVT I32CCVT = getSetCCResultType(And.getValueType());
+ SDValue Zero = DAG.getConstant(0, dl, MVT::i32);
+ SDValue AlreadyOdd = DAG.getSetCC(dl, I32CCVT, And, Zero, ISD::SETNE);
+
+ EVT WideSetCCVT = getSetCCResultType(AbsWide.getValueType());
+ SDValue KeepNarrow = DAG.getSetCC(dl, WideSetCCVT, AbsWide,
+ AbsNarrowAsWide, ISD::SETUEQ);
+ KeepNarrow =
+ DAG.getNode(ISD::OR, dl, WideSetCCVT, KeepNarrow, AlreadyOdd);
+ // We morally performed a round-down if `abs_narrow` is smaller than
+ // `abs_wide`.
+ SDValue NarrowIsRd = DAG.getSetCC(dl, WideSetCCVT, AbsWide,
+ AbsNarrowAsWide, ISD::SETOGT);
+ // If the narrow value is odd or exact, pick it.
+ // Otherwise, narrow is even and corresponds to either the rounded-up
+ // or rounded-down value. If narrow is the rounded-down value, we want
+ // the rounded-up value as it will be odd.
+ SDValue Adjust =
+ DAG.getSelect(dl, MVT::i32, NarrowIsRd, One, NegativeOne);
+ Adjust = DAG.getSelect(dl, MVT::i32, KeepNarrow, Zero, Adjust);
+ int ShiftAmount = ValueAsInt.SignBit - 31;
+ SDValue ShiftCnst = DAG.getConstant(
+ ShiftAmount, dl,
+ TLI.getShiftAmountTy(WideIntVT, DAG.getDataLayout()));
+ SignBit = DAG.getNode(ISD::SRL, dl, WideIntVT, SignBit, ShiftCnst);
+ SignBit = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, SignBit);
+ Op = DAG.getNode(ISD::OR, dl, MVT::i32, Adjust, SignBit);
+ } else {
+ Op = DAG.getNode(ISD::BITCAST, dl, MVT::i32, Op);
+ }
+
+ SDValue One = DAG.getConstant(1, dl, MVT::i32);
+ SDValue Lsb = DAG.getNode(
+ ISD::SRL, dl, MVT::i32, Op,
+ DAG.getConstant(16, dl,
+ TLI.getShiftAmountTy(MVT::i32, DAG.getDataLayout())));
+ Lsb = DAG.getNode(ISD::AND, dl, MVT::i32, Lsb, One);
+ SDValue RoundingBias = DAG.getNode(
+ ISD::ADD, dl, MVT::i32, DAG.getConstant(0x7fff, dl, MVT::i32), Lsb);
+ SDValue Add = DAG.getNode(ISD::ADD, dl, MVT::i32, Op, RoundingBias);
+ Op = DAG.getNode(
+ ISD::SRL, dl, MVT::i32, Add,
+ DAG.getConstant(16, dl,
+ TLI.getShiftAmountTy(MVT::i32, DAG.getDataLayout())));
----------------
RKSimon wrote:
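Use DAG.getShiftAmountConstant(16, MVT::i32, dl) here instead of spelling out DAG.getConstant(16, dl, TLI.getShiftAmountTy(MVT::i32, DAG.getDataLayout()))?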
https://github.com/llvm/llvm-project/pull/82399
More information about the llvm-commits mailing list