[llvm] [RISCV] Implement load/store support for XAndesBFHCvt (PR #150350)
Jim Lin via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 23 23:38:42 PDT 2025
https://github.com/tclin914 updated https://github.com/llvm/llvm-project/pull/150350
From 1a8c16be7ca7755f7b2a222faf270167ba863d28 Mon Sep 17 00:00:00 2001
From: Jim Lin <jim at andestech.com>
Date: Thu, 17 Jul 2025 09:23:19 +0800
Subject: [PATCH 1/2] [RISCV] Implement load/store support for XAndesBFHCvt
We use lhu to load 2 bytes from memory into a GPR, OR the GPR with -65536
(0xffff0000) to emulate NaN-boxing, and then move the value to an FPR with
`fmv.w.x`.
To move the value back from the FPR to a GPR, we use `fmv.x.w`, and finally
`sh` stores the lower 2 bytes back to memory.
If Zfh is enabled at the same time, we can simply use flh/fsh to load/store
bf16 directly.
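
For illustration only (not part of the patch), the emulated sequence is
roughly equivalent to the C++ sketch below; the helper names are invented:

  #include <cstdint>

  // Load: lhu zero-extends the 16-bit pattern into a GPR, and OR-ing it
  // with 0xffff0000 (-65536) forms a valid NaN box that fmv.w.x can move
  // into an FPR.
  inline uint32_t nanBoxedBF16Load(const uint16_t *Mem) {
    return uint32_t(*Mem) | 0xffff0000u; // lhu; lui 1048560; or
  }

  // Store: fmv.x.w moves the bits back to a GPR, and a 16-bit store keeps
  // only the low half, i.e. the bf16 payload.
  inline void bf16TruncStore(uint16_t *Mem, uint32_t Bits) {
    *Mem = uint16_t(Bits); // sh stores the lower 2 bytes
  }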
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 54 +++++++++++++++++++
llvm/lib/Target/RISCV/RISCVISelLowering.h | 3 ++
llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td | 33 ++++++++++++
llvm/test/CodeGen/RISCV/xandesbfhcvt.ll | 52 +++++++++++++++++-
4 files changed, 140 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 3918dd21bc09d..8b5ae01282293 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1618,6 +1618,12 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
}
}
+  // Customize load and store operations for bf16 if zfh isn't enabled.
+ if (Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh()) {
+ setOperationAction(ISD::LOAD, MVT::bf16, Custom);
+ setOperationAction(ISD::STORE, MVT::bf16, Custom);
+ }
+
// Function alignments.
const Align FunctionAlignment(Subtarget.hasStdExtZca() ? 2 : 4);
setMinFunctionAlignment(FunctionAlignment);
@@ -7216,6 +7222,47 @@ static SDValue SplitStrictFPVectorOp(SDValue Op, SelectionDAG &DAG) {
return DAG.getMergeValues({V, HiRes.getValue(1)}, DL);
}
+SDValue
+RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Load(SDValue Op,
+ SelectionDAG &DAG) const {
+ assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
+ "Unexpected bfloat16 load lowering");
+
+ SDLoc DL(Op);
+ LoadSDNode *LD = cast<LoadSDNode>(Op.getNode());
+ EVT MemVT = LD->getMemoryVT();
+ SDValue Load = DAG.getExtLoad(
+ ISD::ZEXTLOAD, DL, Subtarget.getXLenVT(), LD->getChain(),
+ LD->getBasePtr(),
+ EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits()),
+ LD->getMemOperand());
+  // OR with a mask to make the bf16 value properly NaN-boxed when flh isn't
+  // available. -65536 is treated as a small constant, so a single lui can
+  // materialize it.
+  SDValue Mask = DAG.getSignedConstant(-65536, DL, Subtarget.getXLenVT());
+  SDValue NanBoxed =
+      DAG.getNode(ISD::OR, DL, Load.getValueType(), {Load, Mask});
+  SDValue ConvertedResult =
+      DAG.getNode(RISCVISD::NDS_FMV_BF16_X, DL, MVT::bf16, NanBoxed);
+ return DAG.getMergeValues({ConvertedResult, Load.getValue(1)}, DL);
+}
+
+SDValue
+RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Store(SDValue Op,
+ SelectionDAG &DAG) const {
+ assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
+ "Unexpected bfloat16 store lowering");
+
+ StoreSDNode *ST = cast<StoreSDNode>(Op.getNode());
+ SDLoc DL(Op);
+ SDValue FMV = DAG.getNode(RISCVISD::NDS_FMV_X_ANYEXTBF16, DL,
+ Subtarget.getXLenVT(), ST->getValue());
+ return DAG.getTruncStore(
+ ST->getChain(), DL, FMV, ST->getBasePtr(),
+ EVT::getIntegerVT(*DAG.getContext(), ST->getMemoryVT().getSizeInBits()),
+ ST->getMemOperand());
+}
+
SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
SelectionDAG &DAG) const {
switch (Op.getOpcode()) {
@@ -7914,6 +7961,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return DAG.getMergeValues({Pair, Chain}, DL);
}
+ if (VT == MVT::bf16)
+ return lowerXAndesBfHCvtBFloat16Load(Op, DAG);
+
// Handle normal vector tuple load.
if (VT.isRISCVVectorTuple()) {
SDLoc DL(Op);
@@ -7998,6 +8048,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
{Store->getChain(), Lo, Hi, Store->getBasePtr()}, MVT::i64,
Store->getMemOperand());
}
+
+ if (VT == MVT::bf16)
+ return lowerXAndesBfHCvtBFloat16Store(Op, DAG);
+
// Handle normal vector tuple store.
if (VT.isRISCVVectorTuple()) {
SDLoc DL(Op);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index f0447e02191ae..ca70c46988b4e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -578,6 +578,9 @@ class RISCVTargetLowering : public TargetLowering {
SDValue lowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
+ SDValue lowerXAndesBfHCvtBFloat16Load(SDValue Op, SelectionDAG &DAG) const;
+ SDValue lowerXAndesBfHCvtBFloat16Store(SDValue Op, SelectionDAG &DAG) const;
+
bool isEligibleForTailCallOptimization(
CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
const SmallVector<CCValAssign, 16> &ArgLocs) const;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td
index 5220815336441..5d0b66aab5320 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td
@@ -10,6 +10,20 @@
//
//===----------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
+// RISC-V specific DAG Nodes.
+//===----------------------------------------------------------------------===//
+
+def SDT_NDS_FMV_BF16_X
+ : SDTypeProfile<1, 1, [SDTCisVT<0, bf16>, SDTCisVT<1, XLenVT>]>;
+def SDT_NDS_FMV_X_ANYEXTBF16
+ : SDTypeProfile<1, 1, [SDTCisVT<0, XLenVT>, SDTCisVT<1, bf16>]>;
+
+def riscv_nds_fmv_bf16_x
+ : SDNode<"RISCVISD::NDS_FMV_BF16_X", SDT_NDS_FMV_BF16_X>;
+def riscv_nds_fmv_x_anyextbf16
+ : SDNode<"RISCVISD::NDS_FMV_X_ANYEXTBF16", SDT_NDS_FMV_X_ANYEXTBF16>;
+
//===----------------------------------------------------------------------===//
// Operand and SDNode transformation definitions.
//===----------------------------------------------------------------------===//
@@ -774,6 +788,25 @@ def : Pat<(bf16 (fpround FPR32:$rs)),
(NDS_FCVT_BF16_S FPR32:$rs)>;
} // Predicates = [HasVendorXAndesBFHCvt]
+let isCodeGenOnly = 1 in {
+def NDS_FMV_BF16_X : FPUnaryOp_r<0b1111000, 0b00000, 0b000, FPR16, GPR, "fmv.w.x">,
+ Sched<[WriteFMovI32ToF32, ReadFMovI32ToF32]>;
+def NDS_FMV_X_BF16 : FPUnaryOp_r<0b1110000, 0b00000, 0b000, GPR, FPR16, "fmv.x.w">,
+ Sched<[WriteFMovF32ToI32, ReadFMovF32ToI32]>;
+}
+
+let Predicates = [HasVendorXAndesBFHCvt] in {
+def : Pat<(riscv_nds_fmv_bf16_x GPR:$src), (NDS_FMV_BF16_X GPR:$src)>;
+def : Pat<(riscv_nds_fmv_x_anyextbf16 (bf16 FPR16:$src)),
+ (NDS_FMV_X_BF16 (bf16 FPR16:$src))>;
+} // Predicates = [HasVendorXAndesBFHCvt]
+
+// Use flh/fsh to load/store bf16 if zfh is enabled.
+let Predicates = [HasStdExtZfh, HasVendorXAndesBFHCvt] in {
+def : LdPat<load, FLH, bf16>;
+def : StPat<store, FSH, FPR16, bf16>;
+} // Predicates = [HasStdExtZfh, HasVendorXAndesBFHCvt]
+
let Predicates = [HasVendorXAndesVBFHCvt] in {
defm PseudoNDS_VFWCVT_S_BF16 : VPseudoVWCVT_S_BF16;
defm PseudoNDS_VFNCVT_BF16_S : VPseudoVNCVT_BF16_S;
diff --git a/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll b/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
index 854d0b659ea73..c0c15172676fd 100644
--- a/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
+++ b/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
@@ -1,8 +1,12 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -mattr=+xandesbfhcvt -target-abi ilp32f \
-; RUN: -verify-machineinstrs < %s | FileCheck %s
+; RUN: -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,XANDESBFHCVT %s
+; RUN: llc -mtriple=riscv32 -mattr=+zfh,+xandesbfhcvt -target-abi ilp32f \
+; RUN: -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,ZFH %s
; RUN: llc -mtriple=riscv64 -mattr=+xandesbfhcvt -target-abi lp64f \
-; RUN: -verify-machineinstrs < %s | FileCheck %s
+; RUN: -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,XANDESBFHCVT %s
+; RUN: llc -mtriple=riscv64 -mattr=+zfh,+xandesbfhcvt -target-abi lp64f \
+; RUN: -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,ZFH %s
define float @fcvt_s_bf16(bfloat %a) nounwind {
; CHECK-LABEL: fcvt_s_bf16:
@@ -21,3 +25,47 @@ define bfloat @fcvt_bf16_s(float %a) nounwind {
%1 = fptrunc float %a to bfloat
ret bfloat %1
}
+
+@sf = dso_local global float 0.000000e+00, align 4
+@bf = dso_local global bfloat 0xR0000, align 2
+
+; Check load and store to bf16.
+define void @loadstorebf16() nounwind {
+; XANDESBFHCVT-LABEL: loadstorebf16:
+; XANDESBFHCVT: # %bb.0: # %entry
+; XANDESBFHCVT-NEXT: lui a0, %hi(.L_MergedGlobals)
+; XANDESBFHCVT-NEXT: lhu a1, %lo(.L_MergedGlobals)(a0)
+; XANDESBFHCVT-NEXT: lui a2, 1048560
+; XANDESBFHCVT-NEXT: or a1, a1, a2
+; XANDESBFHCVT-NEXT: fmv.w.x fa5, a1
+; XANDESBFHCVT-NEXT: addi a1, a0, %lo(.L_MergedGlobals)
+; XANDESBFHCVT-NEXT: nds.fcvt.s.bf16 fa5, fa5
+; XANDESBFHCVT-NEXT: fsw fa5, 4(a1)
+; XANDESBFHCVT-NEXT: flw fa5, 4(a1)
+; XANDESBFHCVT-NEXT: nds.fcvt.bf16.s fa5, fa5
+; XANDESBFHCVT-NEXT: fmv.x.w a1, fa5
+; XANDESBFHCVT-NEXT: sh a1, %lo(.L_MergedGlobals)(a0)
+; XANDESBFHCVT-NEXT: ret
+;
+; ZFH-LABEL: loadstorebf16:
+; ZFH: # %bb.0: # %entry
+; ZFH-NEXT: lui a0, %hi(.L_MergedGlobals)
+; ZFH-NEXT: flh fa5, %lo(.L_MergedGlobals)(a0)
+; ZFH-NEXT: addi a1, a0, %lo(.L_MergedGlobals)
+; ZFH-NEXT: nds.fcvt.s.bf16 fa5, fa5
+; ZFH-NEXT: fsw fa5, 4(a1)
+; ZFH-NEXT: flw fa5, 4(a1)
+; ZFH-NEXT: nds.fcvt.bf16.s fa5, fa5
+; ZFH-NEXT: fsh fa5, %lo(.L_MergedGlobals)(a0)
+; ZFH-NEXT: ret
+entry:
+ %0 = load bfloat, bfloat* @bf, align 2
+ %1 = fpext bfloat %0 to float
+ store volatile float %1, float* @sf, align 4
+
+ %2 = load float, float* @sf, align 4
+ %3 = fptrunc float %2 to bfloat
+ store volatile bfloat %3, bfloat* @bf, align 2
+
+ ret void
+}
From 11dc0c3a1058e637dedec97e87cf1c44a5cd6f78 Mon Sep 17 00:00:00 2001
From: Jim Lin <jim at andestech.com>
Date: Thu, 24 Jul 2025 14:36:44 +0800
Subject: [PATCH 2/2] Load/store bf16 from/to addresses passed as arguments
instead of global variables
---
llvm/test/CodeGen/RISCV/xandesbfhcvt.ll | 39 ++++++++++---------------
1 file changed, 16 insertions(+), 23 deletions(-)
diff --git a/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll b/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
index c0c15172676fd..72242f1dd312d 100644
--- a/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
+++ b/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
@@ -26,46 +26,39 @@ define bfloat @fcvt_bf16_s(float %a) nounwind {
ret bfloat %1
}
-@sf = dso_local global float 0.000000e+00, align 4
-@bf = dso_local global bfloat 0xR0000, align 2
-
; Check load and store to bf16.
-define void @loadstorebf16() nounwind {
+define void @loadstorebf16(ptr %bf, ptr %sf) nounwind {
; XANDESBFHCVT-LABEL: loadstorebf16:
; XANDESBFHCVT: # %bb.0: # %entry
-; XANDESBFHCVT-NEXT: lui a0, %hi(.L_MergedGlobals)
-; XANDESBFHCVT-NEXT: lhu a1, %lo(.L_MergedGlobals)(a0)
-; XANDESBFHCVT-NEXT: lui a2, 1048560
-; XANDESBFHCVT-NEXT: or a1, a1, a2
-; XANDESBFHCVT-NEXT: fmv.w.x fa5, a1
-; XANDESBFHCVT-NEXT: addi a1, a0, %lo(.L_MergedGlobals)
+; XANDESBFHCVT-NEXT: lhu a2, 0(a0)
+; XANDESBFHCVT-NEXT: lui a3, 1048560
+; XANDESBFHCVT-NEXT: or a2, a2, a3
+; XANDESBFHCVT-NEXT: fmv.w.x fa5, a2
; XANDESBFHCVT-NEXT: nds.fcvt.s.bf16 fa5, fa5
-; XANDESBFHCVT-NEXT: fsw fa5, 4(a1)
-; XANDESBFHCVT-NEXT: flw fa5, 4(a1)
+; XANDESBFHCVT-NEXT: fsw fa5, 0(a1)
+; XANDESBFHCVT-NEXT: flw fa5, 0(a1)
; XANDESBFHCVT-NEXT: nds.fcvt.bf16.s fa5, fa5
; XANDESBFHCVT-NEXT: fmv.x.w a1, fa5
-; XANDESBFHCVT-NEXT: sh a1, %lo(.L_MergedGlobals)(a0)
+; XANDESBFHCVT-NEXT: sh a1, 0(a0)
; XANDESBFHCVT-NEXT: ret
;
; ZFH-LABEL: loadstorebf16:
; ZFH: # %bb.0: # %entry
-; ZFH-NEXT: lui a0, %hi(.L_MergedGlobals)
-; ZFH-NEXT: flh fa5, %lo(.L_MergedGlobals)(a0)
-; ZFH-NEXT: addi a1, a0, %lo(.L_MergedGlobals)
+; ZFH-NEXT: flh fa5, 0(a0)
; ZFH-NEXT: nds.fcvt.s.bf16 fa5, fa5
-; ZFH-NEXT: fsw fa5, 4(a1)
-; ZFH-NEXT: flw fa5, 4(a1)
+; ZFH-NEXT: fsw fa5, 0(a1)
+; ZFH-NEXT: flw fa5, 0(a1)
; ZFH-NEXT: nds.fcvt.bf16.s fa5, fa5
-; ZFH-NEXT: fsh fa5, %lo(.L_MergedGlobals)(a0)
+; ZFH-NEXT: fsh fa5, 0(a0)
; ZFH-NEXT: ret
entry:
- %0 = load bfloat, bfloat* @bf, align 2
+ %0 = load bfloat, bfloat* %bf, align 2
%1 = fpext bfloat %0 to float
- store volatile float %1, float* @sf, align 4
+ store volatile float %1, float* %sf, align 4
- %2 = load float, float* @sf, align 4
+ %2 = load float, float* %sf, align 4
%3 = fptrunc float %2 to bfloat
- store volatile bfloat %3, bfloat* @bf, align 2
+ store volatile bfloat %3, bfloat* %bf, align 2
ret void
}