[llvm] Correctly round FP -> BF16 when SDAG expands such nodes (PR #82399)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 20 10:34:58 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-llvm-selectiondag
Author: David Majnemer (majnemer)
Changes:
We did something pretty naive:
- round FP64 -> BF16 by first rounding to FP32
- skip FP32 -> BF16 rounding entirely
- take the top 16 bits of an FP32, which turns some NaNs into infinities

Let's do this in a more principled way: round types with more precision than FP32 to FP32 using round-inexact-to-odd, which avoids double-rounding issues.
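For illustration, here is a minimal scalar sketch of the round-inexact-to-odd narrowing idea the expansion relies on (plain C++20, not the LLVM code itself; the helper name `roundToF32Odd` is made up for this example):

```cpp
#include <bit>
#include <cmath>
#include <cstdint>

// Narrow f64 -> f32 such that the result is either exact or has an odd low
// mantissa bit. Feeding this into a later f32 -> bf16 round-to-nearest-even
// step gives the same answer as rounding f64 -> bf16 directly, so the
// intermediate rounding cannot skew the result (Boldo & Melquiond, 2005).
static float roundToF32Odd(double Wide) {
  if (std::isnan(Wide))
    return static_cast<float>(Wide);               // NaNs pass through
  double AbsWide = std::fabs(Wide);
  float AbsNarrow = static_cast<float>(AbsWide);   // default round-to-nearest-even
  uint32_t Bits = std::bit_cast<uint32_t>(AbsNarrow);
  bool Exact = static_cast<double>(AbsNarrow) == AbsWide;
  bool Odd = (Bits & 1) != 0;
  if (!Exact && !Odd) {
    // AbsNarrow is the even neighbour of the exact value; step toward the odd
    // neighbour: +1 ulp if we rounded down, -1 ulp if we rounded up.
    if (AbsWide > static_cast<double>(AbsNarrow))
      Bits += 1;
    else
      Bits -= 1;
  }
  float Result = std::bit_cast<float>(Bits);
  return std::signbit(Wide) ? -Result : Result;
}
```

The expansion below follows the same outline: work on the absolute value, widen the narrowed result back and compare it against the original to detect inexactness and the rounding direction, then reattach the sign bit.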
---
Patch is 1.09 MiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82399.diff
11 Files Affected:
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+92-2)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+53)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+3)
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+5-5)
- (modified) llvm/test/CodeGen/AMDGPU/bf16.ll (+12511-2859)
- (modified) llvm/test/CodeGen/AMDGPU/fmed3-cast-combine.ll (+14-2)
- (modified) llvm/test/CodeGen/AMDGPU/global-atomics-fp.ll (+174-112)
- (modified) llvm/test/CodeGen/AMDGPU/isel-amdgpu-cs-chain-preserve-cc.ll (+987-475)
- (modified) llvm/test/CodeGen/AMDGPU/local-atomics-fp.ll (+66-38)
- (modified) llvm/test/CodeGen/AMDGPU/vector_shuffle.packed.ll (+213-80)
- (modified) llvm/test/CodeGen/NVPTX/bf16-instructions.ll (+1-1)
``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 252b6e9997a710..3426956a41b3d2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3219,8 +3219,98 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
case ISD::FP_ROUND: {
EVT VT = Node->getValueType(0);
if (VT.getScalarType() == MVT::bf16) {
- Results.push_back(
- DAG.getNode(ISD::FP_TO_BF16, SDLoc(Node), VT, Node->getOperand(0)));
+ if (Node->getConstantOperandVal(1) == 1) {
+ Results.push_back(
+ DAG.getNode(ISD::FP_TO_BF16, SDLoc(Node), VT, Node->getOperand(0)));
+ break;
+ }
+ SDValue Op = Node->getOperand(0);
+ SDValue IsNaN = DAG.getSetCC(dl, getSetCCResultType(Op.getValueType()),
+ Op, Op, ISD::SETUO);
+ if (Op.getValueType() != MVT::f32) {
+ // We are rounding binary64/binary128 -> binary32 -> bfloat16. This
+ // can induce double-rounding which may alter the results. We can
+ // correct for this using a trick explained in: Boldo, Sylvie, and
+ // Guillaume Melquiond. "When double rounding is odd." 17th IMACS
+ // World Congress. 2005.
+ FloatSignAsInt ValueAsInt;
+ getSignAsIntValue(ValueAsInt, dl, Op);
+ EVT WideIntVT = ValueAsInt.IntValue.getValueType();
+ SDValue SignMask = DAG.getConstant(ValueAsInt.SignMask, dl, WideIntVT);
+ SDValue SignBit =
+ DAG.getNode(ISD::AND, dl, WideIntVT, ValueAsInt.IntValue, SignMask);
+ SDValue AbsWide;
+ if (TLI.isOperationLegalOrCustom(ISD::FABS, ValueAsInt.FloatVT)) {
+ AbsWide = DAG.getNode(ISD::FABS, dl, ValueAsInt.FloatVT, Op);
+ } else {
+ SDValue ClearSignMask =
+ DAG.getConstant(~ValueAsInt.SignMask, dl, WideIntVT);
+ SDValue ClearedSign = DAG.getNode(ISD::AND, dl, WideIntVT,
+ ValueAsInt.IntValue, ClearSignMask);
+ AbsWide = modifySignAsInt(ValueAsInt, dl, ClearedSign);
+ }
+ SDValue AbsNarrow =
+ DAG.getNode(ISD::FP_ROUND, dl, MVT::f32, AbsWide,
+ DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
+ SDValue AbsNarrowAsWide =
+ DAG.getNode(ISD::FP_EXTEND, dl, ValueAsInt.FloatVT, AbsNarrow);
+
+ // We can keep the narrow value as-is if narrowing was exact (no
+ // rounding error), the wide value was NaN (the narrow value is also
+ // NaN and should be preserved) or if we rounded to the odd value.
+ SDValue NarrowBits = DAG.getNode(ISD::BITCAST, dl, MVT::i32, AbsNarrow);
+ SDValue One = DAG.getConstant(1, dl, MVT::i32);
+ SDValue NegativeOne = DAG.getConstant(-1, dl, MVT::i32);
+ SDValue And = DAG.getNode(ISD::AND, dl, MVT::i32, NarrowBits, One);
+ EVT I32CCVT = getSetCCResultType(And.getValueType());
+ SDValue Zero = DAG.getConstant(0, dl, MVT::i32);
+ SDValue AlreadyOdd = DAG.getSetCC(dl, I32CCVT, And, Zero, ISD::SETNE);
+
+ EVT WideSetCCVT = getSetCCResultType(AbsWide.getValueType());
+ SDValue KeepNarrow = DAG.getSetCC(dl, WideSetCCVT, AbsWide,
+ AbsNarrowAsWide, ISD::SETUEQ);
+ KeepNarrow =
+ DAG.getNode(ISD::OR, dl, WideSetCCVT, KeepNarrow, AlreadyOdd);
+ // We morally performed a round-down if `abs_narrow` is smaller than
+ // `abs_wide`.
+ SDValue NarrowIsRd = DAG.getSetCC(dl, WideSetCCVT, AbsWide,
+ AbsNarrowAsWide, ISD::SETOGT);
+ // If the narrow value is odd or exact, pick it.
+ // Otherwise, narrow is even and corresponds to either the rounded-up
+ // or rounded-down value. If narrow is the rounded-down value, we want
+ // the rounded-up value as it will be odd.
+ SDValue Adjust =
+ DAG.getSelect(dl, MVT::i32, NarrowIsRd, One, NegativeOne);
+ Adjust = DAG.getSelect(dl, MVT::i32, KeepNarrow, Zero, Adjust);
+ int ShiftAmount = ValueAsInt.SignBit - 31;
+ SDValue ShiftCnst = DAG.getConstant(
+ ShiftAmount, dl,
+ TLI.getShiftAmountTy(WideIntVT, DAG.getDataLayout()));
+ SignBit = DAG.getNode(ISD::SRL, dl, WideIntVT, SignBit, ShiftCnst);
+ SignBit = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, SignBit);
+ Op = DAG.getNode(ISD::OR, dl, MVT::i32, Adjust, SignBit);
+ } else {
+ Op = DAG.getNode(ISD::BITCAST, dl, MVT::i32, Op);
+ }
+
+ SDValue One = DAG.getConstant(1, dl, MVT::i32);
+ SDValue Lsb = DAG.getNode(
+ ISD::SRL, dl, MVT::i32, Op,
+ DAG.getConstant(16, dl,
+ TLI.getShiftAmountTy(MVT::i32, DAG.getDataLayout())));
+ Lsb = DAG.getNode(ISD::AND, dl, MVT::i32, Lsb, One);
+ SDValue RoundingBias = DAG.getNode(
+ ISD::ADD, dl, MVT::i32, DAG.getConstant(0x7fff, dl, MVT::i32), Lsb);
+ SDValue Add = DAG.getNode(ISD::ADD, dl, MVT::i32, Op, RoundingBias);
+ Op = DAG.getNode(
+ ISD::SRL, dl, MVT::i32, Add,
+ DAG.getConstant(16, dl,
+ TLI.getShiftAmountTy(MVT::i32, DAG.getDataLayout())));
+ Op = DAG.getSelect(dl, MVT::i32, IsNaN,
+ DAG.getConstant(0x00007fc0, dl, MVT::i32), Op);
+
+ Op = DAG.getNode(ISD::TRUNCATE, dl, MVT::i16, Op);
+ Results.push_back(DAG.getNode(ISD::BITCAST, dl, MVT::bf16, Op));
break;
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 7f58b312e7a201..e75799ca13b0bb 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -776,6 +776,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
AddPromotedToType(Op, MVT::bf16, MVT::f32);
}
+ if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
+ setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom);
+ setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom);
+ }
+
// sm_80 only has conversions between f32 and bf16. Custom lower all other
// bf16 conversions.
if (STI.hasBF16Math() &&
@@ -2465,6 +2470,50 @@ SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
return Op;
}
+SDValue NVPTXTargetLowering::LowerFP_ROUND(SDValue Op,
+ SelectionDAG &DAG) const {
+ if (Op.getValueType() == MVT::bf16) {
+ if (Op.getOperand(0).getValueType() == MVT::f32 &&
+ (STI.getSmVersion() < 80 || STI.getPTXVersion() < 70)) {
+ SDLoc Loc(Op);
+ return DAG.getNode(ISD::FP_TO_BF16, Loc, MVT::bf16, Op.getOperand(0));
+ }
+ if (Op.getOperand(0).getValueType() == MVT::f64 &&
+ (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
+ SDLoc Loc(Op);
+ return DAG.getNode(ISD::FP_TO_BF16, Loc, MVT::bf16, Op.getOperand(0));
+ }
+ }
+
+ // Everything else is considered legal.
+ return Op;
+}
+
+SDValue NVPTXTargetLowering::LowerFP_EXTEND(SDValue Op,
+ SelectionDAG &DAG) const {
+ if (Op.getOperand(0).getValueType() == MVT::bf16) {
+ if (Op.getValueType() == MVT::f32 &&
+ (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71)) {
+ SDLoc Loc(Op);
+ return DAG.getNode(ISD::BF16_TO_FP, Loc, Op.getValueType(),
+ Op.getOperand(0));
+ }
+ if (Op.getValueType() == MVT::f64 &&
+ (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
+ SDLoc Loc(Op);
+ if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 71) {
+ Op = DAG.getNode(ISD::FP_EXTEND, Loc, MVT::f32, Op.getOperand(0));
+ return DAG.getNode(ISD::FP_EXTEND, Loc, MVT::f64, Op);
+ }
+ return DAG.getNode(ISD::BF16_TO_FP, Loc, Op.getValueType(),
+ Op.getOperand(0));
+ }
+ }
+
+ // Everything else is considered legal.
+ return Op;
+}
+
static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
SDLoc DL(Op);
if (Op.getValueType() != MVT::v2i16)
@@ -2527,6 +2576,10 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::FP_TO_SINT:
case ISD::FP_TO_UINT:
return LowerFP_TO_INT(Op, DAG);
+ case ISD::FP_ROUND:
+ return LowerFP_ROUND(Op, DAG);
+ case ISD::FP_EXTEND:
+ return LowerFP_EXTEND(Op, DAG);
case ISD::VAARG:
return LowerVAARG(Op, DAG);
case ISD::VASTART:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 5d3fd992812ef9..cf1d4580766918 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -618,6 +618,9 @@ class NVPTXTargetLowering : public TargetLowering {
SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const;
+
SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerLOADi1(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 631136ad621464..40d82ebecbed35 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -662,7 +662,7 @@ let hasSideEffects = false in {
// bf16->f32 was introduced early.
[hasPTX<71>, hasSM<80>],
// bf16->everything else needs sm90/ptx78
- [hasPTX<78>, hasSM<90>])>;
+ [hasPTX<78>, hasSM<90>])>;
def _f32 :
NVPTXInst<(outs RC:$dst),
(ins Float32Regs:$src, CvtMode:$mode),
@@ -3647,7 +3647,7 @@ def : Pat<(f16 (fpround Float32Regs:$a)),
// fpround f32 -> bf16
def : Pat<(bf16 (fpround Float32Regs:$a)),
- (CVT_bf16_f32 Float32Regs:$a, CvtRN)>;
+ (CVT_bf16_f32 Float32Regs:$a, CvtRN)>, Requires<[hasPTX<70>, hasSM<80>]>;
// fpround f64 -> f16
def : Pat<(f16 (fpround Float64Regs:$a)),
@@ -3655,7 +3655,7 @@ def : Pat<(f16 (fpround Float64Regs:$a)),
// fpround f64 -> bf16
def : Pat<(bf16 (fpround Float64Regs:$a)),
- (CVT_bf16_f64 Float64Regs:$a, CvtRN)>;
+ (CVT_bf16_f64 Float64Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
// fpround f64 -> f32
def : Pat<(f32 (fpround Float64Regs:$a)),
(CVT_f32_f64 Float64Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>;
@@ -3671,7 +3671,7 @@ def : Pat<(f32 (fpextend (f16 Int16Regs:$a))),
def : Pat<(f32 (fpextend (bf16 Int16Regs:$a))),
(CVT_f32_bf16 Int16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>;
def : Pat<(f32 (fpextend (bf16 Int16Regs:$a))),
- (CVT_f32_bf16 Int16Regs:$a, CvtNONE)>;
+ (CVT_f32_bf16 Int16Regs:$a, CvtNONE)>, Requires<[hasPTX<71>, hasSM<80>]>;
// fpextend f16 -> f64
def : Pat<(f64 (fpextend (f16 Int16Regs:$a))),
@@ -3679,7 +3679,7 @@ def : Pat<(f64 (fpextend (f16 Int16Regs:$a))),
// fpextend bf16 -> f64
def : Pat<(f64 (fpextend (bf16 Int16Regs:$a))),
- (CVT_f64_bf16 Int16Regs:$a, CvtNONE)>;
+ (CVT_f64_bf16 Int16Regs:$a, CvtNONE)>, Requires<[hasPTX<78>, hasSM<90>]>;
// fpextend f32 -> f64
def : Pat<(f64 (fpextend Float32Regs:$a)),
diff --git a/llvm/test/CodeGen/AMDGPU/bf16.ll b/llvm/test/CodeGen/AMDGPU/bf16.ll
index 387c4a16a008ae..39cb0a768701c0 100644
--- a/llvm/test/CodeGen/AMDGPU/bf16.ll
+++ b/llvm/test/CodeGen/AMDGPU/bf16.ll
@@ -1918,8 +1918,14 @@ define void @test_load_store_f32_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
; GFX8: ; %bb.0:
; GFX8-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; GFX8-NEXT: flat_load_dword v0, v[0:1]
+; GFX8-NEXT: v_mov_b32_e32 v1, 0x7fc0
; GFX8-NEXT: s_waitcnt vmcnt(0)
-; GFX8-NEXT: v_lshrrev_b32_e32 v0, 16, v0
+; GFX8-NEXT: v_bfe_u32 v4, v0, 16, 1
+; GFX8-NEXT: v_add_u32_e32 v4, vcc, v4, v0
+; GFX8-NEXT: v_add_u32_e32 v4, vcc, 0x7fff, v4
+; GFX8-NEXT: v_lshrrev_b32_e32 v4, 16, v4
+; GFX8-NEXT: v_cmp_o_f32_e32 vcc, v0, v0
+; GFX8-NEXT: v_cndmask_b32_e32 v0, v1, v4, vcc
; GFX8-NEXT: flat_store_short v[2:3], v0
; GFX8-NEXT: s_waitcnt vmcnt(0)
; GFX8-NEXT: s_setpc_b64 s[30:31]
@@ -1928,8 +1934,15 @@ define void @test_load_store_f32_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
; GFX9: ; %bb.0:
; GFX9-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; GFX9-NEXT: global_load_dword v0, v[0:1], off
+; GFX9-NEXT: s_movk_i32 s4, 0x7fff
+; GFX9-NEXT: v_mov_b32_e32 v1, 0x7fc0
; GFX9-NEXT: s_waitcnt vmcnt(0)
-; GFX9-NEXT: global_store_short_d16_hi v[2:3], v0, off
+; GFX9-NEXT: v_bfe_u32 v4, v0, 16, 1
+; GFX9-NEXT: v_add3_u32 v4, v4, v0, s4
+; GFX9-NEXT: v_lshrrev_b32_e32 v4, 16, v4
+; GFX9-NEXT: v_cmp_o_f32_e32 vcc, v0, v0
+; GFX9-NEXT: v_cndmask_b32_e32 v0, v1, v4, vcc
+; GFX9-NEXT: global_store_short v[2:3], v0, off
; GFX9-NEXT: s_waitcnt vmcnt(0)
; GFX9-NEXT: s_setpc_b64 s[30:31]
;
@@ -1938,7 +1951,12 @@ define void @test_load_store_f32_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; GFX10-NEXT: global_load_dword v0, v[0:1], off
; GFX10-NEXT: s_waitcnt vmcnt(0)
-; GFX10-NEXT: global_store_short_d16_hi v[2:3], v0, off
+; GFX10-NEXT: v_bfe_u32 v1, v0, 16, 1
+; GFX10-NEXT: v_cmp_o_f32_e32 vcc_lo, v0, v0
+; GFX10-NEXT: v_add3_u32 v1, v1, v0, 0x7fff
+; GFX10-NEXT: v_lshrrev_b32_e32 v1, 16, v1
+; GFX10-NEXT: v_cndmask_b32_e32 v0, 0x7fc0, v1, vcc_lo
+; GFX10-NEXT: global_store_short v[2:3], v0, off
; GFX10-NEXT: s_setpc_b64 s[30:31]
;
; GFX11-LABEL: test_load_store_f32_to_bf16:
@@ -1946,7 +1964,14 @@ define void @test_load_store_f32_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; GFX11-NEXT: global_load_b32 v0, v[0:1], off
; GFX11-NEXT: s_waitcnt vmcnt(0)
-; GFX11-NEXT: global_store_d16_hi_b16 v[2:3], v0, off
+; GFX11-NEXT: v_bfe_u32 v1, v0, 16, 1
+; GFX11-NEXT: v_cmp_o_f32_e32 vcc_lo, v0, v0
+; GFX11-NEXT: s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX11-NEXT: v_add3_u32 v1, v1, v0, 0x7fff
+; GFX11-NEXT: v_lshrrev_b32_e32 v1, 16, v1
+; GFX11-NEXT: s_delay_alu instid0(VALU_DEP_1)
+; GFX11-NEXT: v_cndmask_b32_e32 v0, 0x7fc0, v1, vcc_lo
+; GFX11-NEXT: global_store_b16 v[2:3], v0, off
; GFX11-NEXT: s_setpc_b64 s[30:31]
%val = load float, ptr addrspace(1) %in
%val.bf16 = fptrunc float %val to bfloat
@@ -1989,9 +2014,25 @@ define void @test_load_store_f64_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
; GFX8: ; %bb.0:
; GFX8-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; GFX8-NEXT: flat_load_dwordx2 v[0:1], v[0:1]
+; GFX8-NEXT: v_mov_b32_e32 v7, 0x7fc0
; GFX8-NEXT: s_waitcnt vmcnt(0)
-; GFX8-NEXT: v_cvt_f32_f64_e32 v0, v[0:1]
-; GFX8-NEXT: v_lshrrev_b32_e32 v0, 16, v0
+; GFX8-NEXT: v_cvt_f32_f64_e64 v6, |v[0:1]|
+; GFX8-NEXT: v_and_b32_e32 v8, 0x80000000, v1
+; GFX8-NEXT: v_cvt_f64_f32_e32 v[4:5], v6
+; GFX8-NEXT: v_and_b32_e32 v6, 1, v6
+; GFX8-NEXT: v_cmp_eq_u32_e32 vcc, 1, v6
+; GFX8-NEXT: v_cmp_nlg_f64_e64 s[4:5], |v[0:1]|, v[4:5]
+; GFX8-NEXT: v_cmp_gt_f64_e64 s[6:7], |v[0:1]|, v[4:5]
+; GFX8-NEXT: s_or_b64 s[4:5], s[4:5], vcc
+; GFX8-NEXT: v_cndmask_b32_e64 v4, -1, 1, s[6:7]
+; GFX8-NEXT: v_cndmask_b32_e64 v4, v4, 0, s[4:5]
+; GFX8-NEXT: v_or_b32_e32 v5, v4, v8
+; GFX8-NEXT: v_bfe_u32 v4, v4, 16, 1
+; GFX8-NEXT: v_add_u32_e32 v4, vcc, v4, v5
+; GFX8-NEXT: v_add_u32_e32 v4, vcc, 0x7fff, v4
+; GFX8-NEXT: v_cmp_o_f64_e32 vcc, v[0:1], v[0:1]
+; GFX8-NEXT: v_lshrrev_b32_e32 v4, 16, v4
+; GFX8-NEXT: v_cndmask_b32_e32 v0, v7, v4, vcc
; GFX8-NEXT: flat_store_short v[2:3], v0
; GFX8-NEXT: s_waitcnt vmcnt(0)
; GFX8-NEXT: s_setpc_b64 s[30:31]
@@ -2000,9 +2041,26 @@ define void @test_load_store_f64_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
; GFX9: ; %bb.0:
; GFX9-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; GFX9-NEXT: global_load_dwordx2 v[0:1], v[0:1], off
+; GFX9-NEXT: s_brev_b32 s8, 1
+; GFX9-NEXT: s_movk_i32 s9, 0x7fff
+; GFX9-NEXT: v_mov_b32_e32 v7, 0x7fc0
; GFX9-NEXT: s_waitcnt vmcnt(0)
-; GFX9-NEXT: v_cvt_f32_f64_e32 v0, v[0:1]
-; GFX9-NEXT: global_store_short_d16_hi v[2:3], v0, off
+; GFX9-NEXT: v_cvt_f32_f64_e64 v6, |v[0:1]|
+; GFX9-NEXT: v_cvt_f64_f32_e32 v[4:5], v6
+; GFX9-NEXT: v_and_b32_e32 v6, 1, v6
+; GFX9-NEXT: v_cmp_eq_u32_e32 vcc, 1, v6
+; GFX9-NEXT: v_cmp_nlg_f64_e64 s[4:5], |v[0:1]|, v[4:5]
+; GFX9-NEXT: v_cmp_gt_f64_e64 s[6:7], |v[0:1]|, v[4:5]
+; GFX9-NEXT: s_or_b64 s[4:5], s[4:5], vcc
+; GFX9-NEXT: v_cmp_o_f64_e32 vcc, v[0:1], v[0:1]
+; GFX9-NEXT: v_cndmask_b32_e64 v4, -1, 1, s[6:7]
+; GFX9-NEXT: v_cndmask_b32_e64 v4, v4, 0, s[4:5]
+; GFX9-NEXT: v_and_or_b32 v5, v1, s8, v4
+; GFX9-NEXT: v_bfe_u32 v4, v4, 16, 1
+; GFX9-NEXT: v_add3_u32 v4, v4, v5, s9
+; GFX9-NEXT: v_lshrrev_b32_e32 v4, 16, v4
+; GFX9-NEXT: v_cndmask_b32_e32 v0, v7, v4, vcc
+; GFX9-NEXT: global_store_short v[2:3], v0, off
; GFX9-NEXT: s_waitcnt vmcnt(0)
; GFX9-NEXT: s_setpc_b64 s[30:31]
;
@@ -2011,8 +2069,22 @@ define void @test_load_store_f64_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; GFX10-NEXT: global_load_dwordx2 v[0:1], v[0:1], off
; GFX10-NEXT: s_waitcnt vmcnt(0)
-; GFX10-NEXT: v_cvt_f32_f64_e32 v0, v[0:1]
-; GFX10-NEXT: global_store_short_d16_hi v[2:3], v0, off
+; GFX10-NEXT: v_cvt_f32_f64_e64 v6, |v[0:1]|
+; GFX10-NEXT: v_cvt_f64_f32_e32 v[4:5], v6
+; GFX10-NEXT: v_and_b32_e32 v6, 1, v6
+; GFX10-NEXT: v_cmp_eq_u32_e32 vcc_lo, 1, v6
+; GFX10-NEXT: v_cmp_gt_f64_e64 s5, |v[0:1]|, v[4:5]
+; GFX10-NEXT: v_cmp_nlg_f64_e64 s4, |v[0:1]|, v[4:5]
+; GFX10-NEXT: v_cndmask_b32_e64 v4, -1, 1, s5
+; GFX10-NEXT: s_or_b32 s4, s4, vcc_lo
+; GFX10-NEXT: v_cmp_o_f64_e32 vcc_lo, v[0:1], v[0:1]
+; GFX10-NEXT: v_cndmask_b32_e64 v4, v4, 0, s4
+; GFX10-NEXT: v_and_or_b32 v5, 0x80000000, v1, v4
+; GFX10-NEXT: v_bfe_u32 v4, v4, 16, 1
+; GFX10-NEXT: v_add3_u32 v4, v4, v5, 0x7fff
+; GFX10-NEXT: v_lshrrev_b32_e32 v4, 16, v4
+; GFX10-NEXT: v_cndmask_b32_e32 v0, 0x7fc0, v4, vcc_lo
+; GFX10-NEXT: global_store_short v[2:3], v0, off
; GFX10-NEXT: s_setpc_b64 s[30:31]
;
; GFX11-LABEL: test_load_store_f64_to_bf16:
@@ -2020,8 +2092,27 @@ define void @test_load_store_f64_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off
; GFX11-NEXT: s_waitcnt vmcnt(0)
-; GFX11-NEXT: v_cvt_f32_f64_e32 v0, v[0:1]
-; GFX11-NEXT: global_store_d16_hi_b16 v[2:3], v0, off
+; GFX11-NEXT: v_cvt_f32_f64_e64 v6, |v[0:1]|
+; GFX11-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_1)
+; GFX11-NEXT: v_cvt_f64_f32_e32 v[4:5], v6
+; GFX11-NEXT: v_and_b32_e32 v6, 1, v6
+; GFX11-NEXT: v_cmp_eq_u32_e32 vcc_lo, 1, v6
+; GFX11-NEXT: s_delay_alu instid0(VALU_DEP_3) | instskip(SKIP_1) | instid1(VALU_DEP_2)
+; GFX11-NEXT: v_cmp_nlg_f64_e64 s0, |v[0:1]|, v[4:5]
+; GFX11-NEXT: v_cmp_gt_f64_e64 s1, |v[0:1]|, v[4:5]
+; GFX11-NEXT: s_or_b32 s0, s0, vcc_lo
+; GFX11-NEXT: v_cmp_o_f64_e32 vcc_lo, v[0:1], v[0:1]
+; GFX11-NEXT: s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX11-NEXT: v_cndmask_b32_e64 v4, -1, 1, s1
+; GFX11-NEXT: v_cndmask_b32_e64 v4, v4, 0, s0
+; GFX11-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_1)
+; GFX11-NEXT: v_and_or_b32 v5, 0x80000000, v1, v4
+; GFX11-NEXT: v_bfe_u32 v4, v4, 16, 1
+; GFX11-NEXT: v_add3_u32 v4, v4, v5, 0x7fff
+; GFX11-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX11-NEXT: v_lshrrev_b32_e32 v4, 16, v4
+; GFX11-NEXT: v_cndmask_b32_e32 v0, 0x7fc0, v4, vcc_lo
+; GFX11-NEXT: global_store_b16 v[2:3], v0, off
; GFX11-NEXT: s_setpc_b64 s[30:31]
%val = load double, ptr addrspace(1) %in
%val.bf16 = fptrunc double %val to bfloat
@@ -8487,7 +8578,13 @@ define bfloat ...
[truncated]
``````````
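As an aside, the FP32 -> BF16 step uses the usual round-to-nearest-even bias trick and quiets NaNs to 0x7fc0 rather than truncating them. A standalone sketch of that step, again plain C++20 and not the LLVM code itself:

```cpp
#include <bit>
#include <cstdint>
#include <cstdio>

// Round f32 bits to bf16 bits with round-to-nearest-even: add 0x7fff plus the
// current lsb of the bf16 field, then take the top 16 bits. NaNs map to the
// canonical quiet NaN 0x7fc0; naively taking the top 16 bits would turn a NaN
// such as 0x7f800001 into 0x7f80, i.e. +infinity.
static uint16_t f32ToBF16Bits(float F) {
  if (F != F)                                   // NaN check without <cmath>
    return 0x7fc0;
  uint32_t Bits = std::bit_cast<uint32_t>(F);
  uint32_t Lsb = (Bits >> 16) & 1;
  return static_cast<uint16_t>((Bits + 0x7fff + Lsb) >> 16);
}

int main() {
  uint32_t Raw = 0x7f800001u;                   // a NaN with a low-bit payload
  float NaNVal = std::bit_cast<float>(Raw);
  printf("naive=0x%04x rounded=0x%04x\n", Raw >> 16, f32ToBF16Bits(NaNVal));
  // naive=0x7f80 (+infinity in bf16), rounded=0x7fc0 (quiet NaN)
}
```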
https://github.com/llvm/llvm-project/pull/82399