[llvm] [RISCV] Separate more of scalar FP in CC_RISCV. (PR #107908)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 9 13:05:00 PDT 2024
https://github.com/topperc created https://github.com/llvm/llvm-project/pull/107908
Scalar FP calling convention has gotten more complicated with recent changes to Zfinx/Zdinx, the proposed addition of a GPRF16 register class, and the use of customReg for f16/bf16 and other FP types smaller than XLen.
The previous code tried to share a single getReg and getMem call for many different cases. This patch separates all the FP register handling to the top of the function with their own getReg calls. The only exception is f64 with XLen==32, when we are out of FPRs or not able to use FPRs due to the ABI.
The way I've structured this, we no longer need to correct the LocVT for FP back to ValVT before the call to getMem.
>From 7a386deefc9640eaf55223f19b38de631668ac48 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Sat, 7 Sep 2024 11:05:22 -0700
Subject: [PATCH] [RISCV] Separate more of scalar FP in CC_RISCV.
Scalar FP calling convention has gotten more complicated with recent
changes to Zfinx/Zdinx, proposed addition of a GPRF16 register class,
and using customReg for f16/bf16 and other FP types smaller than XLen.
The previous code tried to share a single getReg and getMem call
for many different cases. This patch separates all the FP register
handling to the top of the function with their own getReg calls.
The only exception is f64 with XLen==32, when we are out of FPRs or
not able to use FPRs due to the ABI.
The way I've structured this, we no longer need to correct the LocVT
for FP back to ValVT before the call to getMem.
---
llvm/lib/Target/RISCV/RISCVCallingConv.cpp | 72 +++++++++++-----------
1 file changed, 36 insertions(+), 36 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVCallingConv.cpp b/llvm/lib/Target/RISCV/RISCVCallingConv.cpp
index 7c70f1f8d1ff8d..deba85946be53a 100644
--- a/llvm/lib/Target/RISCV/RISCVCallingConv.cpp
+++ b/llvm/lib/Target/RISCV/RISCVCallingConv.cpp
@@ -299,30 +299,42 @@ bool llvm::CC_RISCV(unsigned ValNo, MVT ValVT, MVT LocVT,
break;
}
- // FPR16, FPR32, and FPR64 alias each other.
- if (State.getFirstUnallocated(ArgFPR32s) == std::size(ArgFPR32s)) {
- UseGPRForF16_F32 = true;
- UseGPRForF64 = true;
+ if ((LocVT == MVT::f16 || LocVT == MVT::bf16) && !UseGPRForF16_F32) {
+ if (MCRegister Reg = State.AllocateReg(ArgFPR16s)) {
+ State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
+ return false;
+ }
}
- // From this point on, rely on UseGPRForF16_F32, UseGPRForF64 and
- // similar local variables rather than directly checking against the target
- // ABI.
+ if (LocVT == MVT::f32 && !UseGPRForF16_F32) {
+ if (MCRegister Reg = State.AllocateReg(ArgFPR32s)) {
+ State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
+ return false;
+ }
+ }
+
+ if (LocVT == MVT::f64 && !UseGPRForF64) {
+ if (MCRegister Reg = State.AllocateReg(ArgFPR64s)) {
+ State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
+ return false;
+ }
+ }
ArrayRef<MCPhysReg> ArgGPRs = RISCV::getArgGPRs(ABI);
- if ((ValVT == MVT::f32 && XLen == 32 && Subtarget.hasStdExtZfinx()) ||
- (ValVT == MVT::f64 && XLen == 64 && Subtarget.hasStdExtZdinx())) {
+ // Zfinx/Zdinx use GPR without a bitcast when possible.
+ if ((LocVT == MVT::f32 && XLen == 32 && Subtarget.hasStdExtZfinx()) ||
+ (LocVT == MVT::f64 && XLen == 64 && Subtarget.hasStdExtZdinx())) {
if (MCRegister Reg = State.AllocateReg(ArgGPRs)) {
State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
return false;
}
}
- if (UseGPRForF16_F32 && (ValVT == MVT::f16 || ValVT == MVT::bf16 ||
- (ValVT == MVT::f32 && XLen == 64))) {
- MCRegister Reg = State.AllocateReg(ArgGPRs);
- if (Reg) {
+ // FP smaller than XLen, uses custom GPR.
+ if (LocVT == MVT::f16 || LocVT == MVT::bf16 ||
+ (LocVT == MVT::f32 && XLen == 64)) {
+ if (MCRegister Reg = State.AllocateReg(ArgGPRs)) {
LocVT = XLenVT;
State.addLoc(
CCValAssign::getCustomReg(ValNo, ValVT, Reg, LocVT, LocInfo));
@@ -330,13 +342,14 @@ bool llvm::CC_RISCV(unsigned ValNo, MVT ValVT, MVT LocVT,
}
}
- if (UseGPRForF16_F32 &&
- (ValVT == MVT::f16 || ValVT == MVT::bf16 || ValVT == MVT::f32)) {
- LocVT = XLenVT;
- LocInfo = CCValAssign::BCvt;
- } else if (UseGPRForF64 && XLen == 64 && ValVT == MVT::f64) {
- LocVT = MVT::i64;
- LocInfo = CCValAssign::BCvt;
+ // Bitcast FP to GPR if we can use a GPR register.
+ if ((XLen == 32 && LocVT == MVT::f32) || (XLen == 64 && LocVT == MVT::f64)) {
+ if (MCRegister Reg = State.AllocateReg(ArgGPRs)) {
+ LocVT = XLenVT;
+ LocInfo = CCValAssign::BCvt;
+ State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
+ return false;
+ }
}
// If this is a variadic argument, the RISC-V calling convention requires
@@ -368,7 +381,7 @@ bool llvm::CC_RISCV(unsigned ValNo, MVT ValVT, MVT LocVT,
// Handle passing f64 on RV32D with a soft float ABI or when floating point
// registers are exhausted.
- if (UseGPRForF64 && XLen == 32 && ValVT == MVT::f64) {
+ if (XLen == 32 && LocVT == MVT::f64) {
assert(PendingLocs.empty() && "Can't lower f64 if it is split");
// Depending on available argument GPRS, f64 may be passed in a pair of
// GPRs, split between a GPR and the stack, or passed completely on the
@@ -430,13 +443,7 @@ bool llvm::CC_RISCV(unsigned ValNo, MVT ValVT, MVT LocVT,
unsigned StoreSizeBytes = XLen / 8;
Align StackAlign = Align(XLen / 8);
- if ((ValVT == MVT::f16 || ValVT == MVT::bf16) && !UseGPRForF16_F32)
- Reg = State.AllocateReg(ArgFPR16s);
- else if (ValVT == MVT::f32 && !UseGPRForF16_F32)
- Reg = State.AllocateReg(ArgFPR32s);
- else if (ValVT == MVT::f64 && !UseGPRForF64)
- Reg = State.AllocateReg(ArgFPR64s);
- else if (ValVT.isVector() || ValVT.isRISCVVectorTuple()) {
+ if (ValVT.isVector() || ValVT.isRISCVVectorTuple()) {
Reg = allocateRVVReg(ValVT, ValNo, State, TLI);
if (Reg) {
// Fixed-length vectors are located in the corresponding scalable-vector
@@ -489,7 +496,7 @@ bool llvm::CC_RISCV(unsigned ValNo, MVT ValVT, MVT LocVT,
return false;
}
- assert((!UseGPRForF16_F32 || !UseGPRForF64 || LocVT == XLenVT ||
+ assert(((ValVT.isFloatingPoint() && !ValVT.isVector()) || LocVT == XLenVT ||
(TLI.getSubtarget().hasVInstructions() &&
(ValVT.isVector() || ValVT.isRISCVVectorTuple()))) &&
"Expected an XLenVT or vector types at this stage");
@@ -499,13 +506,6 @@ bool llvm::CC_RISCV(unsigned ValNo, MVT ValVT, MVT LocVT,
return false;
}
- // When a scalar floating-point value is passed on the stack, no
- // bit-conversion is needed.
- if (ValVT.isFloatingPoint() && LocInfo != CCValAssign::Indirect) {
- assert(!ValVT.isVector());
- LocVT = ValVT;
- LocInfo = CCValAssign::Full;
- }
State.addLoc(CCValAssign::getMem(ValNo, ValVT, StackOffset, LocVT, LocInfo));
return false;
}
More information about the llvm-commits
mailing list