[llvm] [RISCV] Handle RVV return type in calling convention correctly (PR #87736)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Apr 4 20:24:22 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-risc-v
Author: Brandon Wu (4vtomat)
<details>
<summary>Changes</summary>
Return values are handled in the same way as function arguments.
One thing to note is that if a type can be broken down into homogeneous
vector types, e.g. {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}},
it is considered a vector tuple type and needs to be handled by the tuple
type rule.
---
Full diff: https://github.com/llvm/llvm-project/pull/87736.diff
5 Files Affected:
- (modified) llvm/lib/CodeGen/TargetLoweringBase.cpp (+10-2)
- (modified) llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp (+6-4)
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+89-9)
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+6-3)
- (modified) llvm/test/CodeGen/RISCV/rvv/calling-conv.ll (+116)
``````````diff
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index f64ded4f2cf965..6e7b67ded23c84 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -1809,8 +1809,16 @@ void llvm::GetReturnInfo(CallingConv::ID CC, Type *ReturnType,
else if (attr.hasRetAttr(Attribute::ZExt))
Flags.setZExt();
- for (unsigned i = 0; i < NumParts; ++i)
- Outs.push_back(ISD::OutputArg(Flags, PartVT, VT, /*isfixed=*/true, 0, 0));
+ for (unsigned i = 0; i < NumParts; ++i) {
+ ISD::ArgFlagsTy OutFlags = Flags;
+ if (NumParts > 1 && i == 0)
+ OutFlags.setSplit();
+ else if (i == NumParts - 1 && i != 0)
+ OutFlags.setSplitEnd();
+
+ Outs.push_back(
+ ISD::OutputArg(OutFlags, PartVT, VT, /*isfixed=*/true, 0, 0));
+ }
}
}
diff --git a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
index 8af4bc658409d4..c18892ac62f247 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
@@ -409,7 +409,7 @@ bool RISCVCallLowering::lowerReturnVal(MachineIRBuilder &MIRBuilder,
splitToValueTypes(OrigRetInfo, SplitRetInfos, DL, CC);
RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(),
- F.getReturnType()};
+ ArrayRef(F.getReturnType())};
RISCVOutgoingValueAssigner Assigner(
CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
/*IsRet=*/true, Dispatcher);
@@ -538,7 +538,8 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
++Index;
}
- RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(), TypeList};
+ RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(),
+ ArrayRef(TypeList)};
RISCVIncomingValueAssigner Assigner(
CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
/*IsRet=*/false, Dispatcher);
@@ -603,7 +604,8 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
Call.addRegMask(TRI->getCallPreservedMask(MF, Info.CallConv));
- RVVArgDispatcher ArgDispatcher{&MF, getTLI<RISCVTargetLowering>(), TypeList};
+ RVVArgDispatcher ArgDispatcher{&MF, getTLI<RISCVTargetLowering>(),
+ ArrayRef(TypeList)};
RISCVOutgoingValueAssigner ArgAssigner(
CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
/*IsRet=*/false, ArgDispatcher);
@@ -635,7 +637,7 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
splitToValueTypes(Info.OrigRet, SplitRetInfos, DL, CC);
RVVArgDispatcher RetDispatcher{&MF, getTLI<RISCVTargetLowering>(),
- F.getReturnType()};
+ ArrayRef(F.getReturnType())};
RISCVIncomingValueAssigner RetAssigner(
CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
/*IsRet=*/true, RetDispatcher);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 279d8a435a04ca..fb5027406b0808 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18327,13 +18327,15 @@ void RISCVTargetLowering::analyzeInputArgs(
unsigned NumArgs = Ins.size();
FunctionType *FType = MF.getFunction().getFunctionType();
- SmallVector<Type *, 4> TypeList;
- if (IsRet)
- TypeList.push_back(MF.getFunction().getReturnType());
- else
+ RVVArgDispatcher Dispatcher;
+ if (IsRet) {
+ Dispatcher = RVVArgDispatcher{&MF, this, ArrayRef(Ins)};
+ } else {
+ SmallVector<Type *, 4> TypeList;
for (const Argument &Arg : MF.getFunction().args())
TypeList.push_back(Arg.getType());
- RVVArgDispatcher Dispatcher{&MF, this, TypeList};
+ Dispatcher = RVVArgDispatcher{&MF, this, ArrayRef(TypeList)};
+ }
for (unsigned i = 0; i != NumArgs; ++i) {
MVT ArgVT = Ins[i].VT;
@@ -18368,7 +18370,7 @@ void RISCVTargetLowering::analyzeOutputArgs(
else if (CLI)
for (const TargetLowering::ArgListEntry &Arg : CLI->getArgs())
TypeList.push_back(Arg.Ty);
- RVVArgDispatcher Dispatcher{&MF, this, TypeList};
+ RVVArgDispatcher Dispatcher{&MF, this, ArrayRef(TypeList)};
for (unsigned i = 0; i != NumArgs; i++) {
MVT ArgVT = Outs[i].VT;
@@ -19272,7 +19274,7 @@ bool RISCVTargetLowering::CanLowerReturn(
SmallVector<CCValAssign, 16> RVLocs;
CCState CCInfo(CallConv, IsVarArg, MF, RVLocs, Context);
- RVVArgDispatcher Dispatcher{&MF, this, MF.getFunction().getReturnType()};
+ RVVArgDispatcher Dispatcher{&MF, this, ArrayRef(Outs)};
for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
MVT VT = Outs[i].VT;
@@ -21077,7 +21079,82 @@ unsigned RISCVTargetLowering::getMinimumJumpTableEntries() const {
return Subtarget.getMinimumJumpTableEntries();
}
-void RVVArgDispatcher::constructArgInfos(ArrayRef<Type *> TypeList) {
+// Build RVVArgInfos from a flattened list of ISD::InputArg / ISD::OutputArg
+// parts, e.g. the parts of a (possibly aggregate) return value.
+//
+// A value that is legalized into several register parts is encoded as a run
+// whose first part has the Split flag and whose last part has SplitEnd;
+// middle parts carry neither flag (see the GetReturnInfo change). If the
+// parts group into two or more values with an identical scalable vector
+// register type and part count, the whole list is treated as one vector
+// tuple; otherwise each vector part gets its own entry.
+template <typename Arg>
+void RVVArgDispatcher::constructArgInfos(ArrayRef<Arg> ArgList) {
+  // Returns the number of parts making up the value that starts at Begin:
+  // 1 for an unsplit value, otherwise the length of the Split ... SplitEnd
+  // run. Requires Begin != End and never steps past End, so a run that is
+  // not terminated by SplitEnd is clipped at the end of the list.
+  auto countValueParts = [](const Arg *Begin, const Arg *End) -> size_t {
+    const Arg *Last = Begin;
+    if (Last->Flags.isSplit())
+      while (!Last->Flags.isSplitEnd() && Last + 1 != End)
+        ++Last;
+    return static_cast<size_t>(Last - Begin) + 1;
+  };
+
+  // Determines whether the parts form at least two values that all share
+  // the same scalable vector register type and the same number of parts,
+  // e.g. the flattened form of
+  // {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}}.
+  auto isHomogeneousScalableVectorType = [&](ArrayRef<Arg> Parts) {
+    const Arg *It = Parts.begin();
+    const Arg *End = Parts.end();
+
+    // Nothing to group (e.g. no return value). Checking this first avoids
+    // reading, or advancing the iterator, past the end of Parts.
+    if (It == End)
+      return false;
+
+    // The first value defines the expected part count and register type.
+    const size_t FirstCount = countValueParts(It, End);
+    const MVT FirstVT = It[FirstCount - 1].VT;
+    It += FirstCount;
+
+    // A single value, or a value that is not a scalable vector, is not a
+    // tuple.
+    if (It == End || !FirstVT.isScalableVector())
+      return false;
+
+    // Every following value must match the first one. All values are
+    // measured with the same helper, so middle parts of a value split into
+    // more than two parts are counted consistently.
+    while (It != End) {
+      const size_t Count = countValueParts(It, End);
+      if (Count != FirstCount || It[Count - 1].VT != FirstVT)
+        return false;
+      It += Count;
+    }
+    return true;
+  };
+
+  if (isHomogeneousScalableVectorType(ArgList)) {
+    // Handle as a tuple type: a single entry covering every part.
+    RVVArgInfos.push_back(
+        {static_cast<unsigned>(ArgList.size()), ArgList[0].VT, false});
+    return;
+  }
+
+  // Handle as normal vector types: one entry per vector part. Only the first
+  // part whose element type is i1 is flagged as the first vector mask.
+  bool FirstVMaskAssigned = false;
+  for (const Arg &Part : ArgList) {
+    MVT RegisterVT = Part.VT;
+
+    // Skip non-RVV register types.
+    if (!RegisterVT.isVector())
+      continue;
+
+    // Fixed-length vectors are mapped to their container type.
+    if (RegisterVT.isFixedLengthVector())
+      RegisterVT = TLI->getContainerForFixedLengthVector(RegisterVT);
+
+    if (!FirstVMaskAssigned && RegisterVT.getVectorElementType() == MVT::i1) {
+      RVVArgInfos.push_back({1, RegisterVT, true});
+      FirstVMaskAssigned = true;
+      continue;
+    }
+
+    RVVArgInfos.push_back({1, RegisterVT, false});
+  }
+}
+
+// Handle multiple args.
+template <>
+void RVVArgDispatcher::constructArgInfos<Type *>(ArrayRef<Type *> TypeList) {
const DataLayout &DL = MF->getDataLayout();
const Function &F = MF->getFunction();
LLVMContext &Context = F.getContext();
@@ -21090,8 +21167,11 @@ void RVVArgDispatcher::constructArgInfos(ArrayRef<Type *> TypeList) {
EVT VT = TLI->getValueType(DL, ElemTy);
MVT RegisterVT =
TLI->getRegisterTypeForCallingConv(Context, F.getCallingConv(), VT);
+ unsigned NumRegs =
+ TLI->getNumRegistersForCallingConv(Context, F.getCallingConv(), VT);
- RVVArgInfos.push_back({STy->getNumElements(), RegisterVT, false});
+ RVVArgInfos.push_back(
+ {NumRegs * STy->getNumElements(), RegisterVT, false});
} else {
SmallVector<EVT, 4> ValueVTs;
ComputeValueVTs(*TLI, DL, Ty, ValueVTs);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index c28552354bf422..a2456f2fab66b1 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -1041,13 +1041,16 @@ class RVVArgDispatcher {
bool FirstVMask = false;
};
+ template <typename Arg>
RVVArgDispatcher(const MachineFunction *MF, const RISCVTargetLowering *TLI,
- ArrayRef<Type *> TypeList)
+ ArrayRef<Arg> ArgList)
: MF(MF), TLI(TLI) {
- constructArgInfos(TypeList);
+ constructArgInfos(ArgList);
compute();
}
+ RVVArgDispatcher() = default;
+
MCPhysReg getNextPhysReg();
private:
@@ -1059,7 +1062,7 @@ class RVVArgDispatcher {
unsigned CurIdx = 0;
- void constructArgInfos(ArrayRef<Type *> TypeList);
+ template <typename Arg> void constructArgInfos(ArrayRef<Arg> Ret);
void compute();
void allocatePhysReg(unsigned NF = 1, unsigned LMul = 1,
unsigned StartReg = 0);
diff --git a/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll b/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
index 90edb994ce8222..647d3158b6167f 100644
--- a/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
@@ -249,3 +249,119 @@ define <vscale x 1 x i64> @case4_2(<vscale x 1 x i64> %0, {<vscale x 8 x i64>, <
%add = add <vscale x 1 x i64> %0, %2
ret <vscale x 1 x i64> %add
}
+
+; Scalable-vector return value: the <vscale x 1 x i64> result of callee1 is
+; received in v8, and the sum is passed to callee2 in v8 as well.
+; callee3 is declared here for the tuple-return tests that follow.
+declare <vscale x 1 x i64> @callee1()
+declare void @callee2(<vscale x 1 x i64>)
+declare void @callee3(<vscale x 4 x i32>)
+define void @caller() {
+; RV32-LABEL: caller:
+; RV32: # %bb.0:
+; RV32-NEXT: addi sp, sp, -16
+; RV32-NEXT: .cfi_def_cfa_offset 16
+; RV32-NEXT: sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT: .cfi_offset ra, -4
+; RV32-NEXT: call callee1
+; RV32-NEXT: vsetvli a0, zero, e64, m1, ta, ma
+; RV32-NEXT: vadd.vv v8, v8, v8
+; RV32-NEXT: call callee2
+; RV32-NEXT: lw ra, 12(sp) # 4-byte Folded Reload
+; RV32-NEXT: addi sp, sp, 16
+; RV32-NEXT: ret
+;
+; RV64-LABEL: caller:
+; RV64: # %bb.0:
+; RV64-NEXT: addi sp, sp, -16
+; RV64-NEXT: .cfi_def_cfa_offset 16
+; RV64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT: .cfi_offset ra, -8
+; RV64-NEXT: call callee1
+; RV64-NEXT: vsetvli a0, zero, e64, m1, ta, ma
+; RV64-NEXT: vadd.vv v8, v8, v8
+; RV64-NEXT: call callee2
+; RV64-NEXT: ld ra, 8(sp) # 8-byte Folded Reload
+; RV64-NEXT: addi sp, sp, 16
+; RV64-NEXT: ret
+ %a = call <vscale x 1 x i64> @callee1()
+ %add = add <vscale x 1 x i64> %a, %a
+ call void @callee2(<vscale x 1 x i64> %add)
+ ret void
+}
+
+; Two-field homogeneous struct return, handled by the vector tuple rule: with
+; LMUL=2 the fields are received in v8 and v10.
+declare {<vscale x 4 x i32>, <vscale x 4 x i32>} @callee_tuple()
+define void @caller_tuple() {
+; RV32-LABEL: caller_tuple:
+; RV32: # %bb.0:
+; RV32-NEXT: addi sp, sp, -16
+; RV32-NEXT: .cfi_def_cfa_offset 16
+; RV32-NEXT: sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT: .cfi_offset ra, -4
+; RV32-NEXT: call callee_tuple
+; RV32-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; RV32-NEXT: vadd.vv v8, v8, v10
+; RV32-NEXT: call callee3
+; RV32-NEXT: lw ra, 12(sp) # 4-byte Folded Reload
+; RV32-NEXT: addi sp, sp, 16
+; RV32-NEXT: ret
+;
+; RV64-LABEL: caller_tuple:
+; RV64: # %bb.0:
+; RV64-NEXT: addi sp, sp, -16
+; RV64-NEXT: .cfi_def_cfa_offset 16
+; RV64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT: .cfi_offset ra, -8
+; RV64-NEXT: call callee_tuple
+; RV64-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; RV64-NEXT: vadd.vv v8, v8, v10
+; RV64-NEXT: call callee3
+; RV64-NEXT: ld ra, 8(sp) # 8-byte Folded Reload
+; RV64-NEXT: addi sp, sp, 16
+; RV64-NEXT: ret
+ %a = call {<vscale x 4 x i32>, <vscale x 4 x i32>} @callee_tuple()
+ %b = extractvalue {<vscale x 4 x i32>, <vscale x 4 x i32>} %a, 0
+ %c = extractvalue {<vscale x 4 x i32>, <vscale x 4 x i32>} %a, 1
+ %add = add <vscale x 4 x i32> %b, %c
+ call void @callee3(<vscale x 4 x i32> %add)
+ ret void
+}
+
+; Nested struct return that flattens to three <vscale x 4 x i32> fields. It is
+; still homogeneous and so is treated as a vector tuple: the fields are
+; received in v8, v10 and v12.
+declare {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} @callee_nested()
+define void @caller_nested() {
+; RV32-LABEL: caller_nested:
+; RV32: # %bb.0:
+; RV32-NEXT: addi sp, sp, -16
+; RV32-NEXT: .cfi_def_cfa_offset 16
+; RV32-NEXT: sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT: .cfi_offset ra, -4
+; RV32-NEXT: call callee_nested
+; RV32-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; RV32-NEXT: vadd.vv v8, v8, v10
+; RV32-NEXT: vadd.vv v8, v8, v12
+; RV32-NEXT: call callee3
+; RV32-NEXT: lw ra, 12(sp) # 4-byte Folded Reload
+; RV32-NEXT: addi sp, sp, 16
+; RV32-NEXT: ret
+;
+; RV64-LABEL: caller_nested:
+; RV64: # %bb.0:
+; RV64-NEXT: addi sp, sp, -16
+; RV64-NEXT: .cfi_def_cfa_offset 16
+; RV64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT: .cfi_offset ra, -8
+; RV64-NEXT: call callee_nested
+; RV64-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; RV64-NEXT: vadd.vv v8, v8, v10
+; RV64-NEXT: vadd.vv v8, v8, v12
+; RV64-NEXT: call callee3
+; RV64-NEXT: ld ra, 8(sp) # 8-byte Folded Reload
+; RV64-NEXT: addi sp, sp, 16
+; RV64-NEXT: ret
+ %a = call {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} @callee_nested()
+ %b = extractvalue {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} %a, 0
+ %c = extractvalue {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} %a, 1
+ %c0 = extractvalue {<vscale x 4 x i32>, <vscale x 4 x i32>} %c, 0
+ %c1 = extractvalue {<vscale x 4 x i32>, <vscale x 4 x i32>} %c, 1
+ %add0 = add <vscale x 4 x i32> %b, %c0
+ %add1 = add <vscale x 4 x i32> %add0, %c1
+ call void @callee3(<vscale x 4 x i32> %add1)
+ ret void
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/87736
More information about the llvm-commits
mailing list