[llvm] Correctly round FP -> BF16 when SDAG expands such nodes (PR #82399)

via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 20 10:34:58 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-selectiondag

Author: David Majnemer (majnemer)

<details>
<summary>Changes</summary>

We did something pretty naive:
- round FP64 -> BF16 by first rounding to FP32
- skip FP32 -> BF16 rounding entirely
- take the top 16 bits of an FP32, which turns some NaNs into infinities

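Let's do this in a more principled way: types with more precision than FP32 are first rounded to FP32 using round-inexact-to-odd, which avoids double-rounding issues; the FP32 value is then rounded to BF16 using round-to-nearest-even, with NaN inputs mapped to a quiet NaN.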

---

Patch is 1.09 MiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82399.diff


11 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+92-2) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+53) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+3) 
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+5-5) 
- (modified) llvm/test/CodeGen/AMDGPU/bf16.ll (+12511-2859) 
- (modified) llvm/test/CodeGen/AMDGPU/fmed3-cast-combine.ll (+14-2) 
- (modified) llvm/test/CodeGen/AMDGPU/global-atomics-fp.ll (+174-112) 
- (modified) llvm/test/CodeGen/AMDGPU/isel-amdgpu-cs-chain-preserve-cc.ll (+987-475) 
- (modified) llvm/test/CodeGen/AMDGPU/local-atomics-fp.ll (+66-38) 
- (modified) llvm/test/CodeGen/AMDGPU/vector_shuffle.packed.ll (+213-80) 
- (modified) llvm/test/CodeGen/NVPTX/bf16-instructions.ll (+1-1) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 252b6e9997a710..3426956a41b3d2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3219,8 +3219,98 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
   case ISD::FP_ROUND: {
     EVT VT = Node->getValueType(0);
     if (VT.getScalarType() == MVT::bf16) {
-      Results.push_back(
-          DAG.getNode(ISD::FP_TO_BF16, SDLoc(Node), VT, Node->getOperand(0)));
+      if (Node->getConstantOperandVal(1) == 1) {
+        Results.push_back(
+            DAG.getNode(ISD::FP_TO_BF16, SDLoc(Node), VT, Node->getOperand(0)));
+        break;
+      }
+      SDValue Op = Node->getOperand(0);
+      SDValue IsNaN = DAG.getSetCC(dl, getSetCCResultType(Op.getValueType()),
+                                   Op, Op, ISD::SETUO);
+      if (Op.getValueType() != MVT::f32) {
+        // We are rounding binary64/binary128 -> binary32 -> bfloat16. This
+        // can induce double-rounding which may alter the results. We can
+        // correct for this using a trick explained in: Boldo, Sylvie, and
+        // Guillaume Melquiond. "When double rounding is odd." 17th IMACS
+        // World Congress. 2005.
+        FloatSignAsInt ValueAsInt;
+        getSignAsIntValue(ValueAsInt, dl, Op);
+        EVT WideIntVT = ValueAsInt.IntValue.getValueType();
+        SDValue SignMask = DAG.getConstant(ValueAsInt.SignMask, dl, WideIntVT);
+        SDValue SignBit =
+            DAG.getNode(ISD::AND, dl, WideIntVT, ValueAsInt.IntValue, SignMask);
+        SDValue AbsWide;
+        if (TLI.isOperationLegalOrCustom(ISD::FABS, ValueAsInt.FloatVT)) {
+          AbsWide = DAG.getNode(ISD::FABS, dl, ValueAsInt.FloatVT, Op);
+        } else {
+          SDValue ClearSignMask =
+              DAG.getConstant(~ValueAsInt.SignMask, dl, WideIntVT);
+          SDValue ClearedSign = DAG.getNode(ISD::AND, dl, WideIntVT,
+                                            ValueAsInt.IntValue, ClearSignMask);
+          AbsWide = modifySignAsInt(ValueAsInt, dl, ClearedSign);
+        }
+        SDValue AbsNarrow =
+            DAG.getNode(ISD::FP_ROUND, dl, MVT::f32, AbsWide,
+                        DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
+        SDValue AbsNarrowAsWide =
+            DAG.getNode(ISD::FP_EXTEND, dl, ValueAsInt.FloatVT, AbsNarrow);
+
+        // We can keep the narrow value as-is if narrowing was exact (no
+        // rounding error), the wide value was NaN (the narrow value is also
+        // NaN and should be preserved) or if we rounded to the odd value.
+        SDValue NarrowBits = DAG.getNode(ISD::BITCAST, dl, MVT::i32, AbsNarrow);
+        SDValue One = DAG.getConstant(1, dl, MVT::i32);
+        SDValue NegativeOne = DAG.getConstant(-1, dl, MVT::i32);
+        SDValue And = DAG.getNode(ISD::AND, dl, MVT::i32, NarrowBits, One);
+        EVT I32CCVT = getSetCCResultType(And.getValueType());
+        SDValue Zero = DAG.getConstant(0, dl, MVT::i32);
+        SDValue AlreadyOdd = DAG.getSetCC(dl, I32CCVT, And, Zero, ISD::SETNE);
+
+        EVT WideSetCCVT = getSetCCResultType(AbsWide.getValueType());
+        SDValue KeepNarrow = DAG.getSetCC(dl, WideSetCCVT, AbsWide,
+                                          AbsNarrowAsWide, ISD::SETUEQ);
+        KeepNarrow =
+            DAG.getNode(ISD::OR, dl, WideSetCCVT, KeepNarrow, AlreadyOdd);
+        // We morally performed a round-down if `abs_narrow` is smaller than
+        // `abs_wide`.
+        SDValue NarrowIsRd = DAG.getSetCC(dl, WideSetCCVT, AbsWide,
+                                          AbsNarrowAsWide, ISD::SETOGT);
+        // If the narrow value is odd or exact, pick it.
+        // Otherwise, narrow is even and corresponds to either the rounded-up
+        // or rounded-down value. If narrow is the rounded-down value, we want
+        // the rounded-up value as it will be odd.
+        SDValue Adjust =
+            DAG.getSelect(dl, MVT::i32, NarrowIsRd, One, NegativeOne);
+        Adjust = DAG.getSelect(dl, MVT::i32, KeepNarrow, Zero, Adjust);
+        int ShiftAmount = ValueAsInt.SignBit - 31;
+        SDValue ShiftCnst = DAG.getConstant(
+            ShiftAmount, dl,
+            TLI.getShiftAmountTy(WideIntVT, DAG.getDataLayout()));
+        SignBit = DAG.getNode(ISD::SRL, dl, WideIntVT, SignBit, ShiftCnst);
+        SignBit = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, SignBit);
+        Op = DAG.getNode(ISD::OR, dl, MVT::i32, Adjust, SignBit);
+      } else {
+        Op = DAG.getNode(ISD::BITCAST, dl, MVT::i32, Op);
+      }
+
+      SDValue One = DAG.getConstant(1, dl, MVT::i32);
+      SDValue Lsb = DAG.getNode(
+          ISD::SRL, dl, MVT::i32, Op,
+          DAG.getConstant(16, dl,
+                          TLI.getShiftAmountTy(MVT::i32, DAG.getDataLayout())));
+      Lsb = DAG.getNode(ISD::AND, dl, MVT::i32, Lsb, One);
+      SDValue RoundingBias = DAG.getNode(
+          ISD::ADD, dl, MVT::i32, DAG.getConstant(0x7fff, dl, MVT::i32), Lsb);
+      SDValue Add = DAG.getNode(ISD::ADD, dl, MVT::i32, Op, RoundingBias);
+      Op = DAG.getNode(
+          ISD::SRL, dl, MVT::i32, Add,
+          DAG.getConstant(16, dl,
+                          TLI.getShiftAmountTy(MVT::i32, DAG.getDataLayout())));
+      Op = DAG.getSelect(dl, MVT::i32, IsNaN,
+                         DAG.getConstant(0x00007fc0, dl, MVT::i32), Op);
+
+      Op = DAG.getNode(ISD::TRUNCATE, dl, MVT::i16, Op);
+      Results.push_back(DAG.getNode(ISD::BITCAST, dl, MVT::bf16, Op));
       break;
     }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 7f58b312e7a201..e75799ca13b0bb 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -776,6 +776,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
       AddPromotedToType(Op, MVT::bf16, MVT::f32);
   }
 
+  if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
+    setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom);
+    setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom);
+  }
+
   // sm_80 only has conversions between f32 and bf16. Custom lower all other
   // bf16 conversions.
   if (STI.hasBF16Math() &&
@@ -2465,6 +2470,50 @@ SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
   return Op;
 }
 
+SDValue NVPTXTargetLowering::LowerFP_ROUND(SDValue Op,
+                                           SelectionDAG &DAG) const {
+  if (Op.getValueType() == MVT::bf16) {
+    if (Op.getOperand(0).getValueType() == MVT::f32 &&
+        (STI.getSmVersion() < 80 || STI.getPTXVersion() < 70)) {
+      SDLoc Loc(Op);
+      return DAG.getNode(ISD::FP_TO_BF16, Loc, MVT::bf16, Op.getOperand(0));
+    }
+    if (Op.getOperand(0).getValueType() == MVT::f64 &&
+        (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
+      SDLoc Loc(Op);
+      return DAG.getNode(ISD::FP_TO_BF16, Loc, MVT::bf16, Op.getOperand(0));
+    }
+  }
+
+  // Everything else is considered legal.
+  return Op;
+}
+
+SDValue NVPTXTargetLowering::LowerFP_EXTEND(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  if (Op.getOperand(0).getValueType() == MVT::bf16) {
+    if (Op.getValueType() == MVT::f32 &&
+        (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71)) {
+      SDLoc Loc(Op);
+      return DAG.getNode(ISD::BF16_TO_FP, Loc, Op.getValueType(),
+                         Op.getOperand(0));
+    }
+    if (Op.getValueType() == MVT::f64 &&
+        (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
+      SDLoc Loc(Op);
+      if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 71) {
+        Op = DAG.getNode(ISD::FP_EXTEND, Loc, MVT::f32, Op.getOperand(0));
+        return DAG.getNode(ISD::FP_EXTEND, Loc, MVT::f64, Op);
+      }
+      return DAG.getNode(ISD::BF16_TO_FP, Loc, Op.getValueType(),
+                         Op.getOperand(0));
+    }
+  }
+
+  // Everything else is considered legal.
+  return Op;
+}
+
 static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
   SDLoc DL(Op);
   if (Op.getValueType() != MVT::v2i16)
@@ -2527,6 +2576,10 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   case ISD::FP_TO_SINT:
   case ISD::FP_TO_UINT:
     return LowerFP_TO_INT(Op, DAG);
+  case ISD::FP_ROUND:
+    return LowerFP_ROUND(Op, DAG);
+  case ISD::FP_EXTEND:
+    return LowerFP_EXTEND(Op, DAG);
   case ISD::VAARG:
     return LowerVAARG(Op, DAG);
   case ISD::VASTART:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 5d3fd992812ef9..cf1d4580766918 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -618,6 +618,9 @@ class NVPTXTargetLowering : public TargetLowering {
   SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerLOADi1(SDValue Op, SelectionDAG &DAG) const;
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 631136ad621464..40d82ebecbed35 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -662,7 +662,7 @@ let hasSideEffects = false in {
                    // bf16->f32 was introduced early.
                    [hasPTX<71>, hasSM<80>],
                    // bf16->everything else needs sm90/ptx78
-                   [hasPTX<78>, hasSM<90>])>; 
+                   [hasPTX<78>, hasSM<90>])>;
     def _f32 :
       NVPTXInst<(outs RC:$dst),
                 (ins Float32Regs:$src, CvtMode:$mode),
@@ -3647,7 +3647,7 @@ def : Pat<(f16 (fpround Float32Regs:$a)),
 
 // fpround f32 -> bf16
 def : Pat<(bf16 (fpround Float32Regs:$a)),
-          (CVT_bf16_f32 Float32Regs:$a, CvtRN)>;
+          (CVT_bf16_f32 Float32Regs:$a, CvtRN)>, Requires<[hasPTX<70>, hasSM<80>]>;
 
 // fpround f64 -> f16
 def : Pat<(f16 (fpround Float64Regs:$a)),
@@ -3655,7 +3655,7 @@ def : Pat<(f16 (fpround Float64Regs:$a)),
 
 // fpround f64 -> bf16
 def : Pat<(bf16 (fpround Float64Regs:$a)),
-          (CVT_bf16_f64 Float64Regs:$a, CvtRN)>;
+          (CVT_bf16_f64 Float64Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
 // fpround f64 -> f32
 def : Pat<(f32 (fpround Float64Regs:$a)),
           (CVT_f32_f64 Float64Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>;
@@ -3671,7 +3671,7 @@ def : Pat<(f32 (fpextend (f16 Int16Regs:$a))),
 def : Pat<(f32 (fpextend (bf16 Int16Regs:$a))),
           (CVT_f32_bf16 Int16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>;
 def : Pat<(f32 (fpextend (bf16 Int16Regs:$a))),
-          (CVT_f32_bf16 Int16Regs:$a, CvtNONE)>;
+          (CVT_f32_bf16 Int16Regs:$a, CvtNONE)>, Requires<[hasPTX<71>, hasSM<80>]>;
 
 // fpextend f16 -> f64
 def : Pat<(f64 (fpextend (f16 Int16Regs:$a))),
@@ -3679,7 +3679,7 @@ def : Pat<(f64 (fpextend (f16 Int16Regs:$a))),
 
 // fpextend bf16 -> f64
 def : Pat<(f64 (fpextend (bf16 Int16Regs:$a))),
-          (CVT_f64_bf16 Int16Regs:$a, CvtNONE)>;
+          (CVT_f64_bf16 Int16Regs:$a, CvtNONE)>, Requires<[hasPTX<78>, hasSM<90>]>;
 
 // fpextend f32 -> f64
 def : Pat<(f64 (fpextend Float32Regs:$a)),
diff --git a/llvm/test/CodeGen/AMDGPU/bf16.ll b/llvm/test/CodeGen/AMDGPU/bf16.ll
index 387c4a16a008ae..39cb0a768701c0 100644
--- a/llvm/test/CodeGen/AMDGPU/bf16.ll
+++ b/llvm/test/CodeGen/AMDGPU/bf16.ll
@@ -1918,8 +1918,14 @@ define void @test_load_store_f32_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
 ; GFX8:       ; %bb.0:
 ; GFX8-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
 ; GFX8-NEXT:    flat_load_dword v0, v[0:1]
+; GFX8-NEXT:    v_mov_b32_e32 v1, 0x7fc0
 ; GFX8-NEXT:    s_waitcnt vmcnt(0)
-; GFX8-NEXT:    v_lshrrev_b32_e32 v0, 16, v0
+; GFX8-NEXT:    v_bfe_u32 v4, v0, 16, 1
+; GFX8-NEXT:    v_add_u32_e32 v4, vcc, v4, v0
+; GFX8-NEXT:    v_add_u32_e32 v4, vcc, 0x7fff, v4
+; GFX8-NEXT:    v_lshrrev_b32_e32 v4, 16, v4
+; GFX8-NEXT:    v_cmp_o_f32_e32 vcc, v0, v0
+; GFX8-NEXT:    v_cndmask_b32_e32 v0, v1, v4, vcc
 ; GFX8-NEXT:    flat_store_short v[2:3], v0
 ; GFX8-NEXT:    s_waitcnt vmcnt(0)
 ; GFX8-NEXT:    s_setpc_b64 s[30:31]
@@ -1928,8 +1934,15 @@ define void @test_load_store_f32_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
 ; GFX9:       ; %bb.0:
 ; GFX9-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
 ; GFX9-NEXT:    global_load_dword v0, v[0:1], off
+; GFX9-NEXT:    s_movk_i32 s4, 0x7fff
+; GFX9-NEXT:    v_mov_b32_e32 v1, 0x7fc0
 ; GFX9-NEXT:    s_waitcnt vmcnt(0)
-; GFX9-NEXT:    global_store_short_d16_hi v[2:3], v0, off
+; GFX9-NEXT:    v_bfe_u32 v4, v0, 16, 1
+; GFX9-NEXT:    v_add3_u32 v4, v4, v0, s4
+; GFX9-NEXT:    v_lshrrev_b32_e32 v4, 16, v4
+; GFX9-NEXT:    v_cmp_o_f32_e32 vcc, v0, v0
+; GFX9-NEXT:    v_cndmask_b32_e32 v0, v1, v4, vcc
+; GFX9-NEXT:    global_store_short v[2:3], v0, off
 ; GFX9-NEXT:    s_waitcnt vmcnt(0)
 ; GFX9-NEXT:    s_setpc_b64 s[30:31]
 ;
@@ -1938,7 +1951,12 @@ define void @test_load_store_f32_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
 ; GFX10-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
 ; GFX10-NEXT:    global_load_dword v0, v[0:1], off
 ; GFX10-NEXT:    s_waitcnt vmcnt(0)
-; GFX10-NEXT:    global_store_short_d16_hi v[2:3], v0, off
+; GFX10-NEXT:    v_bfe_u32 v1, v0, 16, 1
+; GFX10-NEXT:    v_cmp_o_f32_e32 vcc_lo, v0, v0
+; GFX10-NEXT:    v_add3_u32 v1, v1, v0, 0x7fff
+; GFX10-NEXT:    v_lshrrev_b32_e32 v1, 16, v1
+; GFX10-NEXT:    v_cndmask_b32_e32 v0, 0x7fc0, v1, vcc_lo
+; GFX10-NEXT:    global_store_short v[2:3], v0, off
 ; GFX10-NEXT:    s_setpc_b64 s[30:31]
 ;
 ; GFX11-LABEL: test_load_store_f32_to_bf16:
@@ -1946,7 +1964,14 @@ define void @test_load_store_f32_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
 ; GFX11-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
 ; GFX11-NEXT:    global_load_b32 v0, v[0:1], off
 ; GFX11-NEXT:    s_waitcnt vmcnt(0)
-; GFX11-NEXT:    global_store_d16_hi_b16 v[2:3], v0, off
+; GFX11-NEXT:    v_bfe_u32 v1, v0, 16, 1
+; GFX11-NEXT:    v_cmp_o_f32_e32 vcc_lo, v0, v0
+; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX11-NEXT:    v_add3_u32 v1, v1, v0, 0x7fff
+; GFX11-NEXT:    v_lshrrev_b32_e32 v1, 16, v1
+; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_1)
+; GFX11-NEXT:    v_cndmask_b32_e32 v0, 0x7fc0, v1, vcc_lo
+; GFX11-NEXT:    global_store_b16 v[2:3], v0, off
 ; GFX11-NEXT:    s_setpc_b64 s[30:31]
   %val = load float, ptr addrspace(1) %in
   %val.bf16 = fptrunc float %val to bfloat
@@ -1989,9 +2014,25 @@ define void @test_load_store_f64_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
 ; GFX8:       ; %bb.0:
 ; GFX8-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
 ; GFX8-NEXT:    flat_load_dwordx2 v[0:1], v[0:1]
+; GFX8-NEXT:    v_mov_b32_e32 v7, 0x7fc0
 ; GFX8-NEXT:    s_waitcnt vmcnt(0)
-; GFX8-NEXT:    v_cvt_f32_f64_e32 v0, v[0:1]
-; GFX8-NEXT:    v_lshrrev_b32_e32 v0, 16, v0
+; GFX8-NEXT:    v_cvt_f32_f64_e64 v6, |v[0:1]|
+; GFX8-NEXT:    v_and_b32_e32 v8, 0x80000000, v1
+; GFX8-NEXT:    v_cvt_f64_f32_e32 v[4:5], v6
+; GFX8-NEXT:    v_and_b32_e32 v6, 1, v6
+; GFX8-NEXT:    v_cmp_eq_u32_e32 vcc, 1, v6
+; GFX8-NEXT:    v_cmp_nlg_f64_e64 s[4:5], |v[0:1]|, v[4:5]
+; GFX8-NEXT:    v_cmp_gt_f64_e64 s[6:7], |v[0:1]|, v[4:5]
+; GFX8-NEXT:    s_or_b64 s[4:5], s[4:5], vcc
+; GFX8-NEXT:    v_cndmask_b32_e64 v4, -1, 1, s[6:7]
+; GFX8-NEXT:    v_cndmask_b32_e64 v4, v4, 0, s[4:5]
+; GFX8-NEXT:    v_or_b32_e32 v5, v4, v8
+; GFX8-NEXT:    v_bfe_u32 v4, v4, 16, 1
+; GFX8-NEXT:    v_add_u32_e32 v4, vcc, v4, v5
+; GFX8-NEXT:    v_add_u32_e32 v4, vcc, 0x7fff, v4
+; GFX8-NEXT:    v_cmp_o_f64_e32 vcc, v[0:1], v[0:1]
+; GFX8-NEXT:    v_lshrrev_b32_e32 v4, 16, v4
+; GFX8-NEXT:    v_cndmask_b32_e32 v0, v7, v4, vcc
 ; GFX8-NEXT:    flat_store_short v[2:3], v0
 ; GFX8-NEXT:    s_waitcnt vmcnt(0)
 ; GFX8-NEXT:    s_setpc_b64 s[30:31]
@@ -2000,9 +2041,26 @@ define void @test_load_store_f64_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
 ; GFX9:       ; %bb.0:
 ; GFX9-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
 ; GFX9-NEXT:    global_load_dwordx2 v[0:1], v[0:1], off
+; GFX9-NEXT:    s_brev_b32 s8, 1
+; GFX9-NEXT:    s_movk_i32 s9, 0x7fff
+; GFX9-NEXT:    v_mov_b32_e32 v7, 0x7fc0
 ; GFX9-NEXT:    s_waitcnt vmcnt(0)
-; GFX9-NEXT:    v_cvt_f32_f64_e32 v0, v[0:1]
-; GFX9-NEXT:    global_store_short_d16_hi v[2:3], v0, off
+; GFX9-NEXT:    v_cvt_f32_f64_e64 v6, |v[0:1]|
+; GFX9-NEXT:    v_cvt_f64_f32_e32 v[4:5], v6
+; GFX9-NEXT:    v_and_b32_e32 v6, 1, v6
+; GFX9-NEXT:    v_cmp_eq_u32_e32 vcc, 1, v6
+; GFX9-NEXT:    v_cmp_nlg_f64_e64 s[4:5], |v[0:1]|, v[4:5]
+; GFX9-NEXT:    v_cmp_gt_f64_e64 s[6:7], |v[0:1]|, v[4:5]
+; GFX9-NEXT:    s_or_b64 s[4:5], s[4:5], vcc
+; GFX9-NEXT:    v_cmp_o_f64_e32 vcc, v[0:1], v[0:1]
+; GFX9-NEXT:    v_cndmask_b32_e64 v4, -1, 1, s[6:7]
+; GFX9-NEXT:    v_cndmask_b32_e64 v4, v4, 0, s[4:5]
+; GFX9-NEXT:    v_and_or_b32 v5, v1, s8, v4
+; GFX9-NEXT:    v_bfe_u32 v4, v4, 16, 1
+; GFX9-NEXT:    v_add3_u32 v4, v4, v5, s9
+; GFX9-NEXT:    v_lshrrev_b32_e32 v4, 16, v4
+; GFX9-NEXT:    v_cndmask_b32_e32 v0, v7, v4, vcc
+; GFX9-NEXT:    global_store_short v[2:3], v0, off
 ; GFX9-NEXT:    s_waitcnt vmcnt(0)
 ; GFX9-NEXT:    s_setpc_b64 s[30:31]
 ;
@@ -2011,8 +2069,22 @@ define void @test_load_store_f64_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
 ; GFX10-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
 ; GFX10-NEXT:    global_load_dwordx2 v[0:1], v[0:1], off
 ; GFX10-NEXT:    s_waitcnt vmcnt(0)
-; GFX10-NEXT:    v_cvt_f32_f64_e32 v0, v[0:1]
-; GFX10-NEXT:    global_store_short_d16_hi v[2:3], v0, off
+; GFX10-NEXT:    v_cvt_f32_f64_e64 v6, |v[0:1]|
+; GFX10-NEXT:    v_cvt_f64_f32_e32 v[4:5], v6
+; GFX10-NEXT:    v_and_b32_e32 v6, 1, v6
+; GFX10-NEXT:    v_cmp_eq_u32_e32 vcc_lo, 1, v6
+; GFX10-NEXT:    v_cmp_gt_f64_e64 s5, |v[0:1]|, v[4:5]
+; GFX10-NEXT:    v_cmp_nlg_f64_e64 s4, |v[0:1]|, v[4:5]
+; GFX10-NEXT:    v_cndmask_b32_e64 v4, -1, 1, s5
+; GFX10-NEXT:    s_or_b32 s4, s4, vcc_lo
+; GFX10-NEXT:    v_cmp_o_f64_e32 vcc_lo, v[0:1], v[0:1]
+; GFX10-NEXT:    v_cndmask_b32_e64 v4, v4, 0, s4
+; GFX10-NEXT:    v_and_or_b32 v5, 0x80000000, v1, v4
+; GFX10-NEXT:    v_bfe_u32 v4, v4, 16, 1
+; GFX10-NEXT:    v_add3_u32 v4, v4, v5, 0x7fff
+; GFX10-NEXT:    v_lshrrev_b32_e32 v4, 16, v4
+; GFX10-NEXT:    v_cndmask_b32_e32 v0, 0x7fc0, v4, vcc_lo
+; GFX10-NEXT:    global_store_short v[2:3], v0, off
 ; GFX10-NEXT:    s_setpc_b64 s[30:31]
 ;
 ; GFX11-LABEL: test_load_store_f64_to_bf16:
@@ -2020,8 +2092,27 @@ define void @test_load_store_f64_to_bf16(ptr addrspace(1) %in, ptr addrspace(1)
 ; GFX11-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
 ; GFX11-NEXT:    global_load_b64 v[0:1], v[0:1], off
 ; GFX11-NEXT:    s_waitcnt vmcnt(0)
-; GFX11-NEXT:    v_cvt_f32_f64_e32 v0, v[0:1]
-; GFX11-NEXT:    global_store_d16_hi_b16 v[2:3], v0, off
+; GFX11-NEXT:    v_cvt_f32_f64_e64 v6, |v[0:1]|
+; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_1)
+; GFX11-NEXT:    v_cvt_f64_f32_e32 v[4:5], v6
+; GFX11-NEXT:    v_and_b32_e32 v6, 1, v6
+; GFX11-NEXT:    v_cmp_eq_u32_e32 vcc_lo, 1, v6
+; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_3) | instskip(SKIP_1) | instid1(VALU_DEP_2)
+; GFX11-NEXT:    v_cmp_nlg_f64_e64 s0, |v[0:1]|, v[4:5]
+; GFX11-NEXT:    v_cmp_gt_f64_e64 s1, |v[0:1]|, v[4:5]
+; GFX11-NEXT:    s_or_b32 s0, s0, vcc_lo
+; GFX11-NEXT:    v_cmp_o_f64_e32 vcc_lo, v[0:1], v[0:1]
+; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX11-NEXT:    v_cndmask_b32_e64 v4, -1, 1, s1
+; GFX11-NEXT:    v_cndmask_b32_e64 v4, v4, 0, s0
+; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_1)
+; GFX11-NEXT:    v_and_or_b32 v5, 0x80000000, v1, v4
+; GFX11-NEXT:    v_bfe_u32 v4, v4, 16, 1
+; GFX11-NEXT:    v_add3_u32 v4, v4, v5, 0x7fff
+; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX11-NEXT:    v_lshrrev_b32_e32 v4, 16, v4
+; GFX11-NEXT:    v_cndmask_b32_e32 v0, 0x7fc0, v4, vcc_lo
+; GFX11-NEXT:    global_store_b16 v[2:3], v0, off
 ; GFX11-NEXT:    s_setpc_b64 s[30:31]
   %val = load double, ptr addrspace(1) %in
   %val.bf16 = fptrunc double %val to bfloat
@@ -8487,7 +8578,13 @@ define bfloat ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/82399


More information about the llvm-commits mailing list