[llvm] [RISCV] Implement load/store support for XAndesBFHCvt (PR #150350)

Jim Lin via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 23 23:38:42 PDT 2025


https://github.com/tclin914 updated https://github.com/llvm/llvm-project/pull/150350

>From 1a8c16be7ca7755f7b2a222faf270167ba863d28 Mon Sep 17 00:00:00 2001
From: Jim Lin <jim at andestech.com>
Date: Thu, 17 Jul 2025 09:23:19 +0800
Subject: [PATCH 1/2] [RISCV] Implement load/store support for XAndesBFHCvt

We use lhu to load 2 bytes from memory into a gpr, then OR this gpr with
-65536 (0xFFFF0000) to emulate nan-boxing, and then move the value from the
gpr to an fpr using `fmv.w.x`.
To move the value back from the fpr to a gpr, we use `fmv.x.w`, and finally
`sh` stores the lower 2 bytes back to memory.

If zfh is enabled at the same time, we can just use flh/fsh to
load/store bf16 directly.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 54 +++++++++++++++++++
 llvm/lib/Target/RISCV/RISCVISelLowering.h     |  3 ++
 llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td | 33 ++++++++++++
 llvm/test/CodeGen/RISCV/xandesbfhcvt.ll       | 52 +++++++++++++++++-
 4 files changed, 140 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 3918dd21bc09d..8b5ae01282293 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1618,6 +1618,12 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     }
   }
 
+  // bf16 loads/stores need custom lowering when zfh (flh/fsh) is unavailable.
+  if (Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh()) {
+    setOperationAction(ISD::LOAD, MVT::bf16, Custom);
+    setOperationAction(ISD::STORE, MVT::bf16, Custom);
+  }
+
   // Function alignments.
   const Align FunctionAlignment(Subtarget.hasStdExtZca() ? 2 : 4);
   setMinFunctionAlignment(FunctionAlignment);
@@ -7216,6 +7222,47 @@ static SDValue SplitStrictFPVectorOp(SDValue Op, SelectionDAG &DAG) {
   return DAG.getMergeValues({V, HiRes.getValue(1)}, DL);
 }
 
+SDValue
+RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Load(SDValue Op,
+                                                   SelectionDAG &DAG) const {
+  assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
+         "Unexpected bfloat16 load lowering");
+
+  SDLoc DL(Op);
+  auto *LD = cast<LoadSDNode>(Op.getNode());
+  EVT MemVT = LD->getMemoryVT();
+  MVT XLenVT = Subtarget.getXLenVT();
+  SDValue Load = DAG.getExtLoad(
+      ISD::ZEXTLOAD, DL, XLenVT, LD->getChain(), LD->getBasePtr(),
+      EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits()),
+      LD->getMemOperand());
+  // Without flh, NaN-box the 16-bit value by hand: OR in 1s above bit 15 so
+  // that fmv.w.x produces a properly NaN-boxed bf16. -65536 (0xFFFF0000) has
+  // its low 12 bits clear, so it is materialized with a single lui.
+  SDValue Mask = DAG.getSignedConstant(-65536, DL, XLenVT);
+  SDValue NaNBoxed =
+      DAG.getNode(ISD::OR, DL, Load.getValueType(), {Load, Mask});
+  SDValue ConvertedResult =
+      DAG.getNode(RISCVISD::NDS_FMV_BF16_X, DL, MVT::bf16, NaNBoxed);
+  return DAG.getMergeValues({ConvertedResult, Load.getValue(1)}, DL);
+}
+
+SDValue
+RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Store(SDValue Op,
+                                                    SelectionDAG &DAG) const {
+  assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
+         "Unexpected bfloat16 store lowering");
+  // Without fsh: move the bf16 bits to a GPR (fmv.x.w), then truncstore (sh).
+  SDLoc DL(Op);
+  auto *Store = cast<StoreSDNode>(Op.getNode());
+  EVT IntMemVT = EVT::getIntegerVT(*DAG.getContext(),
+                                   Store->getMemoryVT().getSizeInBits());
+  SDValue GPRVal = DAG.getNode(RISCVISD::NDS_FMV_X_ANYEXTBF16, DL,
+                               Subtarget.getXLenVT(), Store->getValue());
+  return DAG.getTruncStore(Store->getChain(), DL, GPRVal, Store->getBasePtr(),
+                           IntMemVT, Store->getMemOperand());
+}
+
 SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
                                             SelectionDAG &DAG) const {
   switch (Op.getOpcode()) {
@@ -7914,6 +7961,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
       return DAG.getMergeValues({Pair, Chain}, DL);
     }
 
+    if (VT == MVT::bf16)
+      return lowerXAndesBfHCvtBFloat16Load(Op, DAG);
+
     // Handle normal vector tuple load.
     if (VT.isRISCVVectorTuple()) {
       SDLoc DL(Op);
@@ -7998,6 +8048,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
           {Store->getChain(), Lo, Hi, Store->getBasePtr()}, MVT::i64,
           Store->getMemOperand());
     }
+
+    if (VT == MVT::bf16)
+      return lowerXAndesBfHCvtBFloat16Store(Op, DAG);
+
     // Handle normal vector tuple store.
     if (VT.isRISCVVectorTuple()) {
       SDLoc DL(Op);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index f0447e02191ae..ca70c46988b4e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -578,6 +578,9 @@ class RISCVTargetLowering : public TargetLowering {
   SDValue lowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue lowerXAndesBfHCvtBFloat16Load(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerXAndesBfHCvtBFloat16Store(SDValue Op, SelectionDAG &DAG) const;
+
   bool isEligibleForTailCallOptimization(
       CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
       const SmallVector<CCValAssign, 16> &ArgLocs) const;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td
index 5220815336441..5d0b66aab5320 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td
@@ -10,6 +10,20 @@
 //
 //===----------------------------------------------------------------------===//
 
+//===----------------------------------------------------------------------===//
+// RISC-V specific DAG Nodes.
+//===----------------------------------------------------------------------===//
+
+def SDT_NDS_FMV_BF16_X
+    : SDTypeProfile<1, 1, [SDTCisVT<0, bf16>, SDTCisVT<1, XLenVT>]>;
+def SDT_NDS_FMV_X_ANYEXTBF16
+    : SDTypeProfile<1, 1, [SDTCisVT<0, XLenVT>, SDTCisVT<1, bf16>]>;
+
+def riscv_nds_fmv_bf16_x        // XLenVT GPR -> bf16, selected to fmv.w.x
+    : SDNode<"RISCVISD::NDS_FMV_BF16_X", SDT_NDS_FMV_BF16_X>;
+def riscv_nds_fmv_x_anyextbf16  // bf16 -> XLenVT GPR, selected to fmv.x.w
+    : SDNode<"RISCVISD::NDS_FMV_X_ANYEXTBF16", SDT_NDS_FMV_X_ANYEXTBF16>;
+
 //===----------------------------------------------------------------------===//
 // Operand and SDNode transformation definitions.
 //===----------------------------------------------------------------------===//
@@ -774,6 +788,25 @@ def : Pat<(bf16 (fpround FPR32:$rs)),
           (NDS_FCVT_BF16_S FPR32:$rs)>;
 } // Predicates = [HasVendorXAndesBFHCvt]
 
+let isCodeGenOnly = 1 in { // fmv.w.x / fmv.x.w with an FPR16 (bf16) operand
+def NDS_FMV_BF16_X : FPUnaryOp_r<0b1111000, 0b00000, 0b000, FPR16, GPR, "fmv.w.x">,
+                     Sched<[WriteFMovI32ToF32, ReadFMovI32ToF32]>;
+def NDS_FMV_X_BF16 : FPUnaryOp_r<0b1110000, 0b00000, 0b000, GPR, FPR16, "fmv.x.w">,
+                     Sched<[WriteFMovF32ToI32, ReadFMovF32ToI32]>;
+}
+
+let Predicates = [HasVendorXAndesBFHCvt] in { // nodes from custom bf16 ld/st
+def : Pat<(riscv_nds_fmv_bf16_x GPR:$src), (NDS_FMV_BF16_X GPR:$src)>;
+def : Pat<(riscv_nds_fmv_x_anyextbf16 (bf16 FPR16:$src)),
+          (NDS_FMV_X_BF16 (bf16 FPR16:$src))>;
+} // Predicates = [HasVendorXAndesBFHCvt]
+
+// Use flh/fsh to load/store bf16 if zfh is enabled.
+let Predicates = [HasStdExtZfh, HasVendorXAndesBFHCvt] in {
+def : LdPat<load, FLH, bf16>;
+def : StPat<store, FSH, FPR16, bf16>;
+} // Predicates = [HasStdExtZfh, HasVendorXAndesBFHCvt]
+
 let Predicates = [HasVendorXAndesVBFHCvt] in {
 defm PseudoNDS_VFWCVT_S_BF16 : VPseudoVWCVT_S_BF16;
 defm PseudoNDS_VFNCVT_BF16_S : VPseudoVNCVT_BF16_S;
diff --git a/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll b/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
index 854d0b659ea73..c0c15172676fd 100644
--- a/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
+++ b/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
@@ -1,8 +1,12 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=riscv32 -mattr=+xandesbfhcvt -target-abi ilp32f \
-; RUN:   -verify-machineinstrs < %s | FileCheck %s
+; RUN:   -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,XANDESBFHCVT %s
+; RUN: llc -mtriple=riscv32 -mattr=+zfh,+xandesbfhcvt -target-abi ilp32f \
+; RUN:   -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,ZFH %s
 ; RUN: llc -mtriple=riscv64 -mattr=+xandesbfhcvt -target-abi lp64f \
-; RUN:   -verify-machineinstrs < %s | FileCheck %s
+; RUN:   -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,XANDESBFHCVT %s
+; RUN: llc -mtriple=riscv64 -mattr=+zfh,+xandesbfhcvt -target-abi lp64f \
+; RUN:   -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,ZFH %s
 
 define float @fcvt_s_bf16(bfloat %a) nounwind {
 ; CHECK-LABEL: fcvt_s_bf16:
@@ -21,3 +25,47 @@ define bfloat @fcvt_bf16_s(float %a) nounwind {
   %1 = fptrunc float %a to bfloat
   ret bfloat %1
 }
+
+@sf = dso_local global float 0.000000e+00, align 4
+@bf = dso_local global bfloat 0xR0000, align 2
+
+; Check load and store to bf16.
+define void @loadstorebf16() nounwind {
+; XANDESBFHCVT-LABEL: loadstorebf16:
+; XANDESBFHCVT:       # %bb.0: # %entry
+; XANDESBFHCVT-NEXT:    lui a0, %hi(.L_MergedGlobals)
+; XANDESBFHCVT-NEXT:    lhu a1, %lo(.L_MergedGlobals)(a0)
+; XANDESBFHCVT-NEXT:    lui a2, 1048560
+; XANDESBFHCVT-NEXT:    or a1, a1, a2
+; XANDESBFHCVT-NEXT:    fmv.w.x fa5, a1
+; XANDESBFHCVT-NEXT:    addi a1, a0, %lo(.L_MergedGlobals)
+; XANDESBFHCVT-NEXT:    nds.fcvt.s.bf16 fa5, fa5
+; XANDESBFHCVT-NEXT:    fsw fa5, 4(a1)
+; XANDESBFHCVT-NEXT:    flw fa5, 4(a1)
+; XANDESBFHCVT-NEXT:    nds.fcvt.bf16.s fa5, fa5
+; XANDESBFHCVT-NEXT:    fmv.x.w a1, fa5
+; XANDESBFHCVT-NEXT:    sh a1, %lo(.L_MergedGlobals)(a0)
+; XANDESBFHCVT-NEXT:    ret
+;
+; ZFH-LABEL: loadstorebf16:
+; ZFH:       # %bb.0: # %entry
+; ZFH-NEXT:    lui a0, %hi(.L_MergedGlobals)
+; ZFH-NEXT:    flh fa5, %lo(.L_MergedGlobals)(a0)
+; ZFH-NEXT:    addi a1, a0, %lo(.L_MergedGlobals)
+; ZFH-NEXT:    nds.fcvt.s.bf16 fa5, fa5
+; ZFH-NEXT:    fsw fa5, 4(a1)
+; ZFH-NEXT:    flw fa5, 4(a1)
+; ZFH-NEXT:    nds.fcvt.bf16.s fa5, fa5
+; ZFH-NEXT:    fsh fa5, %lo(.L_MergedGlobals)(a0)
+; ZFH-NEXT:    ret
+entry:
+  %0 = load bfloat, bfloat* @bf, align 2
+  %1 = fpext bfloat %0 to float
+  store volatile float %1, float* @sf, align 4
+
+  %2 = load float, float* @sf, align 4
+  %3 = fptrunc float %2 to bfloat
+  store volatile bfloat %3, bfloat* @bf, align 2
+
+  ret void
+}

>From 11dc0c3a1058e637dedec97e87cf1c44a5cd6f78 Mon Sep 17 00:00:00 2001
From: Jim Lin <jim at andestech.com>
Date: Thu, 24 Jul 2025 14:36:44 +0800
Subject: [PATCH 2/2] Load/store bf16 from/to addresses passed as arguments
 instead of global variables

---
 llvm/test/CodeGen/RISCV/xandesbfhcvt.ll | 39 ++++++++++---------------
 1 file changed, 16 insertions(+), 23 deletions(-)

diff --git a/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll b/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
index c0c15172676fd..72242f1dd312d 100644
--- a/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
+++ b/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
@@ -26,46 +26,39 @@ define bfloat @fcvt_bf16_s(float %a) nounwind {
   ret bfloat %1
 }
 
-@sf = dso_local global float 0.000000e+00, align 4
-@bf = dso_local global bfloat 0xR0000, align 2
-
 ; Check load and store to bf16.
-define void @loadstorebf16() nounwind {
+define void @loadstorebf16(ptr %bf, ptr %sf) nounwind {
 ; XANDESBFHCVT-LABEL: loadstorebf16:
 ; XANDESBFHCVT:       # %bb.0: # %entry
-; XANDESBFHCVT-NEXT:    lui a0, %hi(.L_MergedGlobals)
-; XANDESBFHCVT-NEXT:    lhu a1, %lo(.L_MergedGlobals)(a0)
-; XANDESBFHCVT-NEXT:    lui a2, 1048560
-; XANDESBFHCVT-NEXT:    or a1, a1, a2
-; XANDESBFHCVT-NEXT:    fmv.w.x fa5, a1
-; XANDESBFHCVT-NEXT:    addi a1, a0, %lo(.L_MergedGlobals)
+; XANDESBFHCVT-NEXT:    lhu a2, 0(a0)
+; XANDESBFHCVT-NEXT:    lui a3, 1048560
+; XANDESBFHCVT-NEXT:    or a2, a2, a3
+; XANDESBFHCVT-NEXT:    fmv.w.x fa5, a2
 ; XANDESBFHCVT-NEXT:    nds.fcvt.s.bf16 fa5, fa5
-; XANDESBFHCVT-NEXT:    fsw fa5, 4(a1)
-; XANDESBFHCVT-NEXT:    flw fa5, 4(a1)
+; XANDESBFHCVT-NEXT:    fsw fa5, 0(a1)
+; XANDESBFHCVT-NEXT:    flw fa5, 0(a1)
 ; XANDESBFHCVT-NEXT:    nds.fcvt.bf16.s fa5, fa5
 ; XANDESBFHCVT-NEXT:    fmv.x.w a1, fa5
-; XANDESBFHCVT-NEXT:    sh a1, %lo(.L_MergedGlobals)(a0)
+; XANDESBFHCVT-NEXT:    sh a1, 0(a0)
 ; XANDESBFHCVT-NEXT:    ret
 ;
 ; ZFH-LABEL: loadstorebf16:
 ; ZFH:       # %bb.0: # %entry
-; ZFH-NEXT:    lui a0, %hi(.L_MergedGlobals)
-; ZFH-NEXT:    flh fa5, %lo(.L_MergedGlobals)(a0)
-; ZFH-NEXT:    addi a1, a0, %lo(.L_MergedGlobals)
+; ZFH-NEXT:    flh fa5, 0(a0)
 ; ZFH-NEXT:    nds.fcvt.s.bf16 fa5, fa5
-; ZFH-NEXT:    fsw fa5, 4(a1)
-; ZFH-NEXT:    flw fa5, 4(a1)
+; ZFH-NEXT:    fsw fa5, 0(a1)
+; ZFH-NEXT:    flw fa5, 0(a1)
 ; ZFH-NEXT:    nds.fcvt.bf16.s fa5, fa5
-; ZFH-NEXT:    fsh fa5, %lo(.L_MergedGlobals)(a0)
+; ZFH-NEXT:    fsh fa5, 0(a0)
 ; ZFH-NEXT:    ret
 entry:
-  %0 = load bfloat, bfloat* @bf, align 2
+  %0 = load bfloat, ptr %bf, align 2
   %1 = fpext bfloat %0 to float
-  store volatile float %1, float* @sf, align 4
+  store volatile float %1, ptr %sf, align 4
 
-  %2 = load float, float* @sf, align 4
+  %2 = load float, ptr %sf, align 4
   %3 = fptrunc float %2 to bfloat
-  store volatile bfloat %3, bfloat* @bf, align 2
+  store volatile bfloat %3, ptr %bf, align 2
 
   ret void
 }



More information about the llvm-commits mailing list