[llvm] [NVPTX] Vectorize loads when lowering byval calls, misc. cleanup (PR #151070)
Alex MacLean via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 30 22:04:34 PDT 2025
https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/151070
From b28612948c87c35b76521c2e8967375364935e17 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Mon, 28 Jul 2025 22:46:52 +0000
Subject: [PATCH 1/3] pre-commit tests
---
.../test/CodeGen/NVPTX/byval-arg-vectorize.ll | 40 +++++++++++++++++++
1 file changed, 40 insertions(+)
create mode 100644 llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll
diff --git a/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll b/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll
new file mode 100644
index 0000000000000..4756b16751f39
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll
@@ -0,0 +1,40 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mcpu=sm_70 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -mcpu=sm_70 | %ptxas-verify -arch=sm_70 %}
+
+target triple = "nvptx64-nvidia-cuda"
+
+%struct.double2 = type { double, double }
+
+declare %struct.double2 @add(ptr align(16) byval(%struct.double2), ptr align(16) byval(%struct.double2))
+
+define void @call_byval(ptr %out, ptr %in1, ptr %in2) {
+; CHECK-LABEL: call_byval(
+; CHECK: {
+; CHECK-NEXT: .reg .b64 %rd<12>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [call_byval_param_0];
+; CHECK-NEXT: { // callseq 0, 0
+; CHECK-NEXT: .param .align 16 .b8 param0[16];
+; CHECK-NEXT: .param .align 16 .b8 param1[16];
+; CHECK-NEXT: .param .align 8 .b8 retval0[16];
+; CHECK-NEXT: ld.param.b64 %rd2, [call_byval_param_2];
+; CHECK-NEXT: ld.b64 %rd3, [%rd2+8];
+; CHECK-NEXT: ld.b64 %rd4, [%rd2];
+; CHECK-NEXT: st.param.v2.b64 [param1], {%rd4, %rd3};
+; CHECK-NEXT: ld.param.b64 %rd5, [call_byval_param_1];
+; CHECK-NEXT: ld.b64 %rd6, [%rd5+8];
+; CHECK-NEXT: ld.b64 %rd7, [%rd5];
+; CHECK-NEXT: st.param.v2.b64 [param0], {%rd7, %rd6};
+; CHECK-NEXT: call.uni (retval0), add, (param0, param1);
+; CHECK-NEXT: ld.param.b64 %rd8, [retval0+8];
+; CHECK-NEXT: ld.param.b64 %rd9, [retval0];
+; CHECK-NEXT: } // callseq 0
+; CHECK-NEXT: st.b64 [%rd1+8], %rd8;
+; CHECK-NEXT: st.b64 [%rd1], %rd9;
+; CHECK-NEXT: ret;
+ %call = call %struct.double2 @add(ptr align(16) byval(%struct.double2) %in1, ptr align(16) byval(%struct.double2) %in2)
+ store %struct.double2 %call, ptr %out, align 16
+ ret void
+}
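
The pre-commit test above captures the lowering before the change: each half of the %struct.double2 byval argument reaches the param space through a separate scalar ld.b64. The second patch merges such adjacent pieces whenever the byval alignment covers the combined access. A minimal standalone sketch of that merge decision, with illustrative names and byte sizes rather than the in-tree EVT-based canMergeParamLoadStoresStartingAt:

#include <cstdint>
#include <iostream>
#include <vector>

// Returns how many equally-sized pieces starting at Idx one access of
// AccessSize bytes could cover (1 means "keep it scalar").
static unsigned mergeablePieces(unsigned Idx, uint32_t AccessSize,
                                const std::vector<uint32_t> &PieceSizes,
                                const std::vector<uint64_t> &Offsets,
                                uint64_t ParamAlign) {
  // The vector access may not exceed the parameter's alignment.
  if (ParamAlign < AccessSize)
    return 1;
  const uint32_t PieceSize = PieceSizes[Idx];
  if (AccessSize <= PieceSize || Offsets[Idx] % AccessSize != 0)
    return 1;
  const unsigned N = AccessSize / PieceSize;
  if (Idx + N > PieceSizes.size())
    return 1;
  // All covered pieces must be the same size and laid out contiguously.
  for (unsigned J = 1; J < N; ++J)
    if (PieceSizes[Idx + J] != PieceSize ||
        Offsets[Idx + J] != Offsets[Idx] + J * PieceSize)
      return 1;
  return N;
}

int main() {
  // %struct.double2: two 8-byte pieces at offsets 0 and 8.
  const std::vector<uint32_t> Sizes = {8, 8};
  const std::vector<uint64_t> Offs = {0, 8};
  std::cout << mergeablePieces(0, 16, Sizes, Offs, /*ParamAlign=*/16) << "\n"; // 2 -> ld.v2.b64
  std::cout << mergeablePieces(0, 16, Sizes, Offs, /*ParamAlign=*/8) << "\n";  // 1 -> scalar loads
}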
From cdfc0e34cc1c45ef2341bd85077cd1ca9bd28303 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Tue, 29 Jul 2025 02:02:37 +0000
Subject: [PATCH 2/3] [NVPTX] Vectorize loads when lowering byval calls, misc. cleanup
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 376 +++++++++---------
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 6 -
.../test/CodeGen/NVPTX/byval-arg-vectorize.ll | 10 +-
.../CodeGen/NVPTX/convert-call-to-indirect.ll | 24 +-
.../CodeGen/NVPTX/lower-args-gridconstant.ll | 66 +--
.../CodeGen/NVPTX/param-vectorize-device.ll | 22 +-
6 files changed, 245 insertions(+), 259 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index f79b8629f01e2..3e48e61f8c6c8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -382,6 +382,51 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
}
}
+static EVT getVectorizedVT(EVT VT, unsigned N, LLVMContext &C) {
+ if (N == 1)
+ return VT;
+
+ const unsigned PackingAmt = VT.isVector() ? VT.getVectorNumElements() : 1;
+ return EVT::getVectorVT(C, VT.getScalarType(), N * PackingAmt);
+}
+
+static SDValue getExtractVectorizedValue(SDValue V, unsigned I, EVT VT,
+ const SDLoc &dl, SelectionDAG &DAG) {
+ if (V.getValueType() == VT) {
+ assert(I == 0 && "Index must be 0 for scalar value");
+ return V;
+ }
+
+ if (!VT.isVector())
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, V,
+ DAG.getVectorIdxConstant(I, dl));
+
+ return DAG.getNode(
+ ISD::EXTRACT_SUBVECTOR, dl, VT, V,
+ DAG.getVectorIdxConstant(I * VT.getVectorNumElements(), dl));
+}
+
+template <typename T>
+static inline SDValue getBuildVectorizedValue(T GetElement, unsigned N,
+ const SDLoc &dl,
+ SelectionDAG &DAG) {
+ if (N == 1)
+ return GetElement(0);
+
+ SmallVector<SDValue, 8> Values;
+ for (const unsigned I : llvm::seq(N)) {
+ SDValue Val = GetElement(I);
+ if (Val.getValueType().isVector())
+ DAG.ExtractVectorElements(Val, Values);
+ else
+ Values.push_back(Val);
+ }
+
+ EVT VT = EVT::getVectorVT(*DAG.getContext(), Values[0].getValueType(),
+ Values.size());
+ return DAG.getBuildVector(VT, dl, Values);
+}
+
/// PromoteScalarIntegerPTX
/// Used to make sure the arguments/returns are suitable for passing
/// and promote them to a larger size if they're not.
@@ -420,9 +465,10 @@ static EVT promoteScalarIntegerPTX(const EVT VT) {
// parameter starting at index Idx using a single vectorized op of
// size AccessSize. If so, it returns the number of param pieces
// covered by the vector op. Otherwise, it returns 1.
-static unsigned CanMergeParamLoadStoresStartingAt(
+template <typename T>
+static unsigned canMergeParamLoadStoresStartingAt(
unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs,
- const SmallVectorImpl<uint64_t> &Offsets, Align ParamAlignment) {
+ const SmallVectorImpl<T> &Offsets, Align ParamAlignment) {
// Can't vectorize if param alignment is not sufficient.
if (ParamAlignment < AccessSize)
@@ -472,10 +518,11 @@ static unsigned CanMergeParamLoadStoresStartingAt(
// of the same size as ValueVTs indicating how each piece should be
// loaded/stored (i.e. as a scalar, or as part of a vector
// load/store).
+template <typename T>
static SmallVector<unsigned, 16>
VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
- const SmallVectorImpl<uint64_t> &Offsets,
- Align ParamAlignment, bool IsVAArg = false) {
+ const SmallVectorImpl<T> &Offsets, Align ParamAlignment,
+ bool IsVAArg = false) {
// Set vector size to match ValueVTs and mark all elements as
// scalars by default.
@@ -486,7 +533,7 @@ VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
const auto GetNumElts = [&](unsigned I) -> unsigned {
for (const unsigned AccessSize : {16, 8, 4, 2}) {
- const unsigned NumElts = CanMergeParamLoadStoresStartingAt(
+ const unsigned NumElts = canMergeParamLoadStoresStartingAt(
I, AccessSize, ValueVTs, Offsets, ParamAlignment);
assert((NumElts == 1 || NumElts == 2 || NumElts == 4) &&
"Unexpected vectorization size");
@@ -1384,6 +1431,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
Type *RetTy = CLI.RetTy;
const CallBase *CB = CLI.CB;
const DataLayout &DL = DAG.getDataLayout();
+ LLVMContext &Ctx = *DAG.getContext();
const auto GetI32 = [&](const unsigned I) {
return DAG.getConstant(I, dl, MVT::i32);
@@ -1476,15 +1524,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
const SDValue ParamSymbol =
getCallParamSymbol(DAG, IsVAArg ? FirstVAArg : ArgI, MVT::i32);
- SmallVector<EVT, 16> VTs;
- SmallVector<uint64_t, 16> Offsets;
-
assert((!IsByVal || Arg.IndirectType) &&
"byval arg must have indirect type");
Type *ETy = (IsByVal ? Arg.IndirectType : Arg.Ty);
- ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, IsByVal ? 0 : VAOffset);
- assert(VTs.size() == Offsets.size() && "Size mismatch");
- assert((IsByVal || VTs.size() == ArgOuts.size()) && "Size mismatch");
const Align ArgAlign = [&]() {
if (IsByVal) {
@@ -1492,17 +1534,14 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// so we don't need to worry whether it's naturally aligned or not.
// See TargetLowering::LowerCallTo().
const Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
- const Align ByValAlign = getFunctionByValParamAlign(
- CB->getCalledFunction(), ETy, InitialAlign, DL);
- if (IsVAArg)
- VAOffset = alignTo(VAOffset, ByValAlign);
- return ByValAlign;
+ return getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
+ InitialAlign, DL);
}
return getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL);
}();
- const unsigned TypeSize = DL.getTypeAllocSize(ETy);
- assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) &&
+ const unsigned TySize = DL.getTypeAllocSize(ETy);
+ assert((!IsByVal || TySize == ArgOuts[0].Flags.getByValSize()) &&
"type size mismatch");
const SDValue ArgDeclare = [&]() {
@@ -1510,105 +1549,119 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
return VADeclareParam;
if (IsByVal || shouldPassAsArray(Arg.Ty))
- return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TypeSize);
+ return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TySize);
assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
assert((ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint()) &&
"Only int and float types are supported as non-array arguments");
- return MakeDeclareScalarParam(ParamSymbol, TypeSize);
+ return MakeDeclareScalarParam(ParamSymbol, TySize);
}();
- // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
- // than 32-bits are sign extended or zero extended, depending on
- // whether they are signed or unsigned types. This case applies
- // only to scalar parameters and not to aggregate values.
- const bool ExtendIntegerParam =
- Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
+ if (IsByVal) {
+ assert(ArgOutVals.size() == 1 && "We must pass only one value as byval");
+ SDValue SrcPtr = ArgOutVals[0];
+ const auto PointerInfo = refinePtrAS(SrcPtr, DAG, DL, *this);
+ const Align BaseSrcAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
- const auto GetStoredValue = [&](const unsigned I, EVT EltVT,
- const MaybeAlign PartAlign) {
- if (IsByVal) {
- SDValue Ptr = ArgOutVals[0];
- auto MPI = refinePtrAS(Ptr, DAG, DL, *this);
- SDValue SrcAddr =
- DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(Offsets[I]));
-
- return DAG.getLoad(EltVT, dl, CallChain, SrcAddr, MPI, PartAlign);
+ if (IsVAArg)
+ VAOffset = alignTo(VAOffset, ArgAlign);
+
+ SmallVector<EVT, 4> ValueVTs, MemVTs;
+ SmallVector<TypeSize, 4> Offsets;
+ ComputeValueVTs(*this, DL, ETy, ValueVTs, &MemVTs, &Offsets);
+
+ unsigned J = 0;
+ const auto VI = VectorizePTXValueVTs(MemVTs, Offsets, ArgAlign, IsVAArg);
+ for (const unsigned NumElts : VI) {
+ EVT LoadVT = getVectorizedVT(MemVTs[J], NumElts, Ctx);
+ Align SrcAlign = commonAlignment(BaseSrcAlign, Offsets[J]);
+ SDValue SrcAddr = DAG.getObjectPtrOffset(dl, SrcPtr, Offsets[J]);
+ SDValue SrcLoad =
+ DAG.getLoad(LoadVT, dl, CallChain, SrcAddr, PointerInfo, SrcAlign);
+
+ TypeSize ParamOffset = Offsets[J].getWithIncrement(VAOffset);
+ Align ParamAlign = commonAlignment(ArgAlign, ParamOffset);
+ SDValue ParamAddr =
+ DAG.getObjectPtrOffset(dl, ParamSymbol, ParamOffset);
+ SDValue StoreParam =
+ DAG.getStore(ArgDeclare, dl, SrcLoad, ParamAddr,
+ MachinePointerInfo(ADDRESS_SPACE_PARAM), ParamAlign);
+ CallPrereqs.push_back(StoreParam);
+
+ J += NumElts;
}
- SDValue StVal = ArgOutVals[I];
- assert(promoteScalarIntegerPTX(StVal.getValueType()) ==
- StVal.getValueType() &&
- "OutVal type should always be legal");
-
- const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
- const EVT StoreVT =
- ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
-
- return correctParamType(StVal, StoreVT, ArgOuts[I].Flags, DAG, dl);
- };
-
- const auto VectorInfo =
- VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
-
- unsigned J = 0;
- for (const unsigned NumElts : VectorInfo) {
- const int CurOffset = Offsets[J];
- const EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
-
- if (IsVAArg && !IsByVal)
- // Align each part of the variadic argument to their type.
- VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
-
- assert((IsVAArg || VAOffset == 0) &&
- "VAOffset must be 0 for non-VA args");
+ if (IsVAArg)
+ VAOffset += TySize;
+ } else {
+ SmallVector<EVT, 16> VTs;
+ SmallVector<uint64_t, 16> Offsets;
+ ComputePTXValueVTs(*this, DL, Arg.Ty, VTs, &Offsets, VAOffset);
+ assert(VTs.size() == Offsets.size() && "Size mismatch");
+ assert(VTs.size() == ArgOuts.size() && "Size mismatch");
+
+ // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
+ // than 32-bits are sign extended or zero extended, depending on
+ // whether they are signed or unsigned types. This case applies
+ // only to scalar parameters and not to aggregate values.
+ const bool ExtendIntegerParam =
+ Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
+
+ const auto GetStoredValue = [&](const unsigned I) {
+ SDValue StVal = ArgOutVals[I];
+ assert(promoteScalarIntegerPTX(StVal.getValueType()) ==
+ StVal.getValueType() &&
+ "OutVal type should always be legal");
+
+ const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
+ const EVT StoreVT =
+ ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
+
+ return correctParamType(StVal, StoreVT, ArgOuts[I].Flags, DAG, dl);
+ };
+
+ unsigned J = 0;
+ const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
+ for (const unsigned NumElts : VI) {
+ const EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
+
+ unsigned Offset;
+ if (IsVAArg) {
+ // TODO: We may need to support vector types that can be passed
+ // as scalars in variadic arguments.
+ assert(NumElts == 1 &&
+ "Vectorization should be disabled for vaargs.");
+
+ // Align each part of the variadic argument to their type.
+ VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
+ Offset = VAOffset;
+
+ const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
+ VAOffset += DL.getTypeAllocSize(TheStoreType.getTypeForEVT(Ctx));
+ } else {
+ assert(VAOffset == 0 && "VAOffset must be 0 for non-VA args");
+ Offset = Offsets[J];
+ }
- const unsigned Offset =
- (VAOffset + ((IsVAArg && !IsByVal) ? 0 : CurOffset));
- SDValue Ptr =
- DAG.getObjectPtrOffset(dl, ParamSymbol, TypeSize::getFixed(Offset));
+ SDValue Ptr =
+ DAG.getObjectPtrOffset(dl, ParamSymbol, TypeSize::getFixed(Offset));
- const MaybeAlign CurrentAlign = ExtendIntegerParam
- ? MaybeAlign(std::nullopt)
- : commonAlignment(ArgAlign, Offset);
+ const MaybeAlign CurrentAlign = ExtendIntegerParam
+ ? MaybeAlign(std::nullopt)
+ : commonAlignment(ArgAlign, Offset);
- SDValue Val;
- if (NumElts == 1) {
- Val = GetStoredValue(J, EltVT, CurrentAlign);
- } else {
- SmallVector<SDValue, 8> StoreVals;
- for (const unsigned K : llvm::seq(NumElts)) {
- SDValue ValJ = GetStoredValue(J + K, EltVT, CurrentAlign);
- if (ValJ.getValueType().isVector())
- DAG.ExtractVectorElements(ValJ, StoreVals);
- else
- StoreVals.push_back(ValJ);
- }
+ SDValue Val = getBuildVectorizedValue(
+ [&](unsigned K) { return GetStoredValue(J + K); }, NumElts, dl,
+ DAG);
- EVT VT = EVT::getVectorVT(
- *DAG.getContext(), StoreVals[0].getValueType(), StoreVals.size());
- Val = DAG.getBuildVector(VT, dl, StoreVals);
- }
+ SDValue StoreParam =
+ DAG.getStore(ArgDeclare, dl, Val, Ptr,
+ MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
+ CallPrereqs.push_back(StoreParam);
- SDValue StoreParam =
- DAG.getStore(ArgDeclare, dl, Val, Ptr,
- MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
- CallPrereqs.push_back(StoreParam);
-
- // TODO: We may need to support vector types that can be passed
- // as scalars in variadic arguments.
- if (IsVAArg && !IsByVal) {
- assert(NumElts == 1 &&
- "Vectorization is expected to be disabled for variadics.");
- const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
- VAOffset +=
- DL.getTypeAllocSize(TheStoreType.getTypeForEVT(*DAG.getContext()));
+ J += NumElts;
}
-
- J += NumElts;
}
- if (IsVAArg && IsByVal)
- VAOffset += TypeSize;
}
// Handle Result
@@ -1676,17 +1729,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
CallPrereqs.push_back(PrototypeDeclare);
}
- if (ConvertToIndirectCall) {
- // Copy the function ptr to a ptx register and use the register to call the
- // function.
- const MVT DestVT = Callee.getValueType().getSimpleVT();
- MachineRegisterInfo &MRI = DAG.getMachineFunction().getRegInfo();
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
- Register DestReg = MRI.createVirtualRegister(TLI.getRegClassFor(DestVT));
- auto RegCopy = DAG.getCopyToReg(DAG.getEntryNode(), dl, DestReg, Callee);
- Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT);
- }
-
const unsigned Proto = IsIndirectCall ? UniqueCallSite : 0;
const unsigned NumArgs =
std::min<unsigned>(CLI.NumFixedArgs + 1, Args.size());
@@ -1703,10 +1745,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
if (!Ins.empty()) {
SmallVector<EVT, 16> VTs;
SmallVector<uint64_t, 16> Offsets;
- ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0);
+ ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
assert(VTs.size() == Ins.size() && "Bad value decomposition");
const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
+ const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
// PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
// 32-bits are sign extended or zero extended, depending on whether
@@ -1714,9 +1757,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
const bool ExtendIntegerRetVal =
RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
- const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
unsigned I = 0;
- for (const unsigned NumElts : VectorInfo) {
+ const auto VI = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
+ for (const unsigned NumElts : VI) {
const MaybeAlign CurrentAlign =
ExtendIntegerRetVal ? MaybeAlign(std::nullopt)
: commonAlignment(RetAlign, Offsets[I]);
@@ -1724,16 +1767,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
const EVT LoadVT =
ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
-
- const unsigned PackingAmt =
- LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1;
-
- const EVT VecVT = NumElts == 1 ? LoadVT
- : EVT::getVectorVT(*DAG.getContext(),
- LoadVT.getScalarType(),
- NumElts * PackingAmt);
-
- const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
+ const EVT VecVT = getVectorizedVT(LoadVT, NumElts, Ctx);
SDValue Ptr =
DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
@@ -1742,17 +1776,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
LoadChains.push_back(R.getValue(1));
-
- if (NumElts == 1)
- ProxyRegOps.push_back(R);
- else
- for (const unsigned J : llvm::seq(NumElts)) {
- SDValue Elt = DAG.getNode(
- LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
- : ISD::EXTRACT_VECTOR_ELT,
- dl, LoadVT, R, DAG.getVectorIdxConstant(J * PackingAmt, dl));
- ProxyRegOps.push_back(Elt);
- }
+ for (const unsigned J : llvm::seq(NumElts))
+ ProxyRegOps.push_back(getExtractVectorizedValue(R, J, LoadVT, dl, DAG));
I += NumElts;
}
}
@@ -3227,11 +3252,10 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
- MachineFunction &MF = DAG.getMachineFunction();
const DataLayout &DL = DAG.getDataLayout();
auto PtrVT = getPointerTy(DAG.getDataLayout());
- const Function *F = &MF.getFunction();
+ const Function &F = DAG.getMachineFunction().getFunction();
SDValue Root = DAG.getRoot();
SmallVector<SDValue, 16> OutChains;
@@ -3247,7 +3271,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
// See similar issue in LowerCall.
auto AllIns = ArrayRef(Ins);
- for (const auto &Arg : F->args()) {
+ for (const auto &Arg : F.args()) {
const auto ArgIns = AllIns.take_while(
[&](auto I) { return I.OrigArgIndex == Arg.getArgNo(); });
AllIns = AllIns.drop_front(ArgIns.size());
@@ -3287,7 +3311,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
assert(ByvalIn.VT == PtrVT && "ByVal argument must be a pointer");
SDValue P;
- if (isKernelFunction(*F)) {
+ if (isKernelFunction(F)) {
P = ArgSymbol;
P.getNode()->setIROrder(Arg.getArgNo() + 1);
} else {
@@ -3305,43 +3329,27 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
assert(VTs.size() == Offsets.size() && "Size mismatch");
const Align ArgAlign = getFunctionArgumentAlignment(
- F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
+ &F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
- const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
unsigned I = 0;
- for (const unsigned NumElts : VectorInfo) {
+ const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
+ for (const unsigned NumElts : VI) {
// i1 is loaded/stored as i8
const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
- // If the element is a packed type (ex. v2f16, v4i8, etc) holding
- // multiple elements.
- const unsigned PackingAmt =
- LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1;
-
- const EVT VecVT =
- NumElts == 1
- ? LoadVT
- : EVT::getVectorVT(F->getContext(), LoadVT.getScalarType(),
- NumElts * PackingAmt);
+ const EVT VecVT = getVectorizedVT(LoadVT, NumElts, *DAG.getContext());
SDValue VecAddr = DAG.getObjectPtrOffset(
dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
- const MaybeAlign PartAlign = commonAlignment(ArgAlign, Offsets[I]);
+ const Align PartAlign = commonAlignment(ArgAlign, Offsets[I]);
SDValue P =
DAG.getLoad(VecVT, dl, Root, VecAddr,
MachinePointerInfo(ADDRESS_SPACE_PARAM), PartAlign,
MachineMemOperand::MODereferenceable |
MachineMemOperand::MOInvariant);
- if (P.getNode())
- P.getNode()->setIROrder(Arg.getArgNo() + 1);
+ P.getNode()->setIROrder(Arg.getArgNo() + 1);
for (const unsigned J : llvm::seq(NumElts)) {
- SDValue Elt =
- NumElts == 1
- ? P
- : DAG.getNode(LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
- : ISD::EXTRACT_VECTOR_ELT,
- dl, LoadVT, P,
- DAG.getVectorIdxConstant(J * PackingAmt, dl));
+ SDValue Elt = getExtractVectorizedValue(P, J, LoadVT, dl, DAG);
Elt = correctParamType(Elt, ArgIns[I + J].VT, ArgIns[I + J].Flags,
DAG, dl);
@@ -3364,9 +3372,8 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
const SmallVectorImpl<ISD::OutputArg> &Outs,
const SmallVectorImpl<SDValue> &OutVals,
const SDLoc &dl, SelectionDAG &DAG) const {
- const MachineFunction &MF = DAG.getMachineFunction();
- const Function &F = MF.getFunction();
- Type *RetTy = MF.getFunction().getReturnType();
+ const Function &F = DAG.getMachineFunction().getFunction();
+ Type *RetTy = F.getReturnType();
if (RetTy->isVoidTy()) {
assert(OutVals.empty() && Outs.empty() && "Return value expected for void");
@@ -3374,10 +3381,9 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
}
const DataLayout &DL = DAG.getDataLayout();
- SmallVector<EVT, 16> VTs;
- SmallVector<uint64_t, 16> Offsets;
- ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
- assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
+
+ const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
+ const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL);
// PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
// 32-bits are sign extended or zero extended, depending on whether
@@ -3385,6 +3391,11 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
const bool ExtendIntegerRetVal =
RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
+ SmallVector<EVT, 16> VTs;
+ SmallVector<uint64_t, 16> Offsets;
+ ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
+ assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
+
const auto GetRetVal = [&](unsigned I) -> SDValue {
SDValue RetVal = OutVals[I];
assert(promoteScalarIntegerPTX(RetVal.getValueType()) ==
@@ -3397,33 +3408,16 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
return correctParamType(RetVal, StoreVT, Outs[I].Flags, DAG, dl);
};
- const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL);
- const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
unsigned I = 0;
- for (const unsigned NumElts : VectorInfo) {
+ const auto VI = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
+ for (const unsigned NumElts : VI) {
const MaybeAlign CurrentAlign = ExtendIntegerRetVal
? MaybeAlign(std::nullopt)
: commonAlignment(RetAlign, Offsets[I]);
- SDValue Val;
- if (NumElts == 1) {
- Val = GetRetVal(I);
- } else {
- SmallVector<SDValue, 4> StoreVals;
- for (const unsigned J : llvm::seq(NumElts)) {
- SDValue ValJ = GetRetVal(I + J);
- if (ValJ.getValueType().isVector())
- DAG.ExtractVectorElements(ValJ, StoreVals);
- else
- StoreVals.push_back(ValJ);
- }
-
- EVT VT = EVT::getVectorVT(F.getContext(), StoreVals[0].getValueType(),
- StoreVals.size());
- Val = DAG.getBuildVector(VT, dl, StoreVals);
- }
+ SDValue Val = getBuildVectorizedValue(
+ [&](unsigned K) { return GetRetVal(I + K); }, NumElts, dl, DAG);
- const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
SDValue Ptr =
DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
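
The new getVectorizedVT / getExtractVectorizedValue / getBuildVectorizedValue helpers factor out the packed-type handling that was previously repeated inline in the call, return, and formal-argument paths. The widening rule is the interesting part: when N pieces are merged and each piece is itself a packed vector (v2f16, v4i8, and friends), the merged type keeps the scalar element type and multiplies the element count, so the result stays a flat vector. A toy model of that rule, using a stand-in struct instead of EVT:

#include <iostream>
#include <string>

struct SimpleVT {
  std::string Scalar; // element type, e.g. "f16" or "f64"
  unsigned NumElts;   // 1 means a scalar type
};

// Mirrors the flattening described for getVectorizedVT: the result can hold
// N values of type VT, flattened to a single level of vector.
static SimpleVT vectorizedVT(SimpleVT VT, unsigned N) {
  if (N == 1)
    return VT;
  return {VT.Scalar, VT.NumElts * N};
}

int main() {
  const SimpleVT Packed{"f16", 2}; // v2f16
  const SimpleVT Scalar{"f64", 1}; // f64
  const SimpleVT A = vectorizedVT(Packed, 4);
  const SimpleVT B = vectorizedVT(Scalar, 2);
  std::cout << "v" << A.NumElts << A.Scalar << "\n"; // v8f16
  std::cout << "v" << B.NumElts << B.Scalar << "\n"; // v2f64
}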
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 86d6f7c3fc3a3..45d3a6044826c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1802,8 +1802,6 @@ foreach is_convergent = [0, 1] in {
}
defvar call_inst = !cast<NVPTXInst>("CALL" # convergent_suffix);
- def : Pat<(call is_convergent, 1, imm:$rets, imm:$params, globaladdr:$addr, imm:$proto),
- (call_inst (to_tglobaladdr $addr), imm:$rets, imm:$params, imm:$proto)>;
def : Pat<(call is_convergent, 1, imm:$rets, imm:$params, i32:$addr, imm:$proto),
(call_inst $addr, imm:$rets, imm:$params, imm:$proto)>;
def : Pat<(call is_convergent, 1, imm:$rets, imm:$params, i64:$addr, imm:$proto),
@@ -1812,10 +1810,6 @@ foreach is_convergent = [0, 1] in {
defvar call_uni_inst = !cast<NVPTXInst>("CALL_UNI" # convergent_suffix);
def : Pat<(call is_convergent, 0, imm:$rets, imm:$params, globaladdr:$addr, 0),
(call_uni_inst (to_tglobaladdr $addr), imm:$rets, imm:$params)>;
- def : Pat<(call is_convergent, 0, imm:$rets, imm:$params, i32:$addr, 0),
- (call_uni_inst $addr, imm:$rets, imm:$params)>;
- def : Pat<(call is_convergent, 0, imm:$rets, imm:$params, i64:$addr, 0),
- (call_uni_inst $addr, imm:$rets, imm:$params)>;
}
def DECLARE_PARAM_array :
diff --git a/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll b/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll
index 4756b16751f39..9988d5b122cc1 100644
--- a/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll
+++ b/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll
@@ -20,13 +20,11 @@ define void @call_byval(ptr %out, ptr %in1, ptr %in2) {
; CHECK-NEXT: .param .align 16 .b8 param1[16];
; CHECK-NEXT: .param .align 8 .b8 retval0[16];
; CHECK-NEXT: ld.param.b64 %rd2, [call_byval_param_2];
-; CHECK-NEXT: ld.b64 %rd3, [%rd2+8];
-; CHECK-NEXT: ld.b64 %rd4, [%rd2];
-; CHECK-NEXT: st.param.v2.b64 [param1], {%rd4, %rd3};
+; CHECK-NEXT: ld.v2.b64 {%rd3, %rd4}, [%rd2];
+; CHECK-NEXT: st.param.v2.b64 [param1], {%rd3, %rd4};
; CHECK-NEXT: ld.param.b64 %rd5, [call_byval_param_1];
-; CHECK-NEXT: ld.b64 %rd6, [%rd5+8];
-; CHECK-NEXT: ld.b64 %rd7, [%rd5];
-; CHECK-NEXT: st.param.v2.b64 [param0], {%rd7, %rd6};
+; CHECK-NEXT: ld.v2.b64 {%rd6, %rd7}, [%rd5];
+; CHECK-NEXT: st.param.v2.b64 [param0], {%rd6, %rd7};
; CHECK-NEXT: call.uni (retval0), add, (param0, param1);
; CHECK-NEXT: ld.param.b64 %rd8, [retval0+8];
; CHECK-NEXT: ld.param.b64 %rd9, [retval0];
diff --git a/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll b/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll
index 48209a8c88682..dd3e4ecddcd2e 100644
--- a/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll
+++ b/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll
@@ -12,14 +12,14 @@ define %struct.64 @test_return_type_mismatch(ptr %p) {
; CHECK-NEXT: .reg .b64 %rd<40>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: ld.param.b64 %rd2, [test_return_type_mismatch_param_0];
+; CHECK-NEXT: ld.param.b64 %rd1, [test_return_type_mismatch_param_0];
; CHECK-NEXT: { // callseq 0, 0
; CHECK-NEXT: .param .b64 param0;
; CHECK-NEXT: .param .align 1 .b8 retval0[8];
-; CHECK-NEXT: st.param.b64 [param0], %rd2;
+; CHECK-NEXT: st.param.b64 [param0], %rd1;
; CHECK-NEXT: prototype_0 : .callprototype (.param .align 1 .b8 _[8]) _ (.param .b64 _);
-; CHECK-NEXT: mov.b64 %rd1, callee;
-; CHECK-NEXT: call (retval0), %rd1, (param0), prototype_0;
+; CHECK-NEXT: mov.b64 %rd2, callee;
+; CHECK-NEXT: call (retval0), %rd2, (param0), prototype_0;
; CHECK-NEXT: ld.param.b8 %rd3, [retval0+7];
; CHECK-NEXT: ld.param.b8 %rd4, [retval0+6];
; CHECK-NEXT: ld.param.b8 %rd5, [retval0+5];
@@ -90,16 +90,16 @@ define i64 @test_param_count_mismatch(ptr %p) {
; CHECK-NEXT: .reg .b64 %rd<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: ld.param.b64 %rd2, [test_param_count_mismatch_param_0];
+; CHECK-NEXT: ld.param.b64 %rd1, [test_param_count_mismatch_param_0];
; CHECK-NEXT: { // callseq 2, 0
; CHECK-NEXT: .param .b64 param0;
; CHECK-NEXT: .param .b64 param1;
; CHECK-NEXT: .param .b64 retval0;
-; CHECK-NEXT: st.param.b64 [param0], %rd2;
+; CHECK-NEXT: st.param.b64 [param0], %rd1;
; CHECK-NEXT: prototype_2 : .callprototype (.param .b64 _) _ (.param .b64 _, .param .b64 _);
; CHECK-NEXT: st.param.b64 [param1], 7;
-; CHECK-NEXT: mov.b64 %rd1, callee;
-; CHECK-NEXT: call (retval0), %rd1, (param0, param1), prototype_2;
+; CHECK-NEXT: mov.b64 %rd2, callee;
+; CHECK-NEXT: call (retval0), %rd2, (param0, param1), prototype_2;
; CHECK-NEXT: ld.param.b64 %rd3, [retval0];
; CHECK-NEXT: } // callseq 2
; CHECK-NEXT: st.param.b64 [func_retval0], %rd3;
@@ -114,14 +114,14 @@ define %struct.64 @test_return_type_mismatch_variadic(ptr %p) {
; CHECK-NEXT: .reg .b64 %rd<40>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: ld.param.b64 %rd2, [test_return_type_mismatch_variadic_param_0];
+; CHECK-NEXT: ld.param.b64 %rd1, [test_return_type_mismatch_variadic_param_0];
; CHECK-NEXT: { // callseq 3, 0
; CHECK-NEXT: .param .b64 param0;
; CHECK-NEXT: .param .align 1 .b8 retval0[8];
-; CHECK-NEXT: st.param.b64 [param0], %rd2;
+; CHECK-NEXT: st.param.b64 [param0], %rd1;
; CHECK-NEXT: prototype_3 : .callprototype (.param .align 1 .b8 _[8]) _ (.param .b64 _);
-; CHECK-NEXT: mov.b64 %rd1, callee_variadic;
-; CHECK-NEXT: call (retval0), %rd1, (param0), prototype_3;
+; CHECK-NEXT: mov.b64 %rd2, callee_variadic;
+; CHECK-NEXT: call (retval0), %rd2, (param0), prototype_3;
; CHECK-NEXT: ld.param.b8 %rd3, [retval0+7];
; CHECK-NEXT: ld.param.b8 %rd4, [retval0+6];
; CHECK-NEXT: ld.param.b8 %rd5, [retval0+5];
diff --git a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
index 38185c7bf30de..045704bdcd3fc 100644
--- a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
@@ -124,15 +124,15 @@ define ptx_kernel void @grid_const_escape(ptr byval(%struct.s) align 4 %input) {
; PTX-NEXT: .reg .b64 %rd<4>;
; PTX-EMPTY:
; PTX-NEXT: // %bb.0:
-; PTX-NEXT: mov.b64 %rd2, grid_const_escape_param_0;
-; PTX-NEXT: cvta.param.u64 %rd3, %rd2;
+; PTX-NEXT: mov.b64 %rd1, grid_const_escape_param_0;
+; PTX-NEXT: cvta.param.u64 %rd2, %rd1;
; PTX-NEXT: { // callseq 0, 0
; PTX-NEXT: .param .b64 param0;
; PTX-NEXT: .param .b32 retval0;
-; PTX-NEXT: st.param.b64 [param0], %rd3;
+; PTX-NEXT: st.param.b64 [param0], %rd2;
; PTX-NEXT: prototype_0 : .callprototype (.param .b32 _) _ (.param .b64 _);
-; PTX-NEXT: mov.b64 %rd1, escape;
-; PTX-NEXT: call (retval0), %rd1, (param0), prototype_0;
+; PTX-NEXT: mov.b64 %rd3, escape;
+; PTX-NEXT: call (retval0), %rd3, (param0), prototype_0;
; PTX-NEXT: } // callseq 0
; PTX-NEXT: ret;
; OPT-LABEL: define ptx_kernel void @grid_const_escape(
@@ -157,25 +157,25 @@ define ptx_kernel void @multiple_grid_const_escape(ptr byval(%struct.s) align 4
; PTX-NEXT: // %bb.0:
; PTX-NEXT: mov.b64 %SPL, __local_depot4;
; PTX-NEXT: cvta.local.u64 %SP, %SPL;
-; PTX-NEXT: mov.b64 %rd2, multiple_grid_const_escape_param_0;
+; PTX-NEXT: mov.b64 %rd1, multiple_grid_const_escape_param_0;
; PTX-NEXT: ld.param.b32 %r1, [multiple_grid_const_escape_param_1];
-; PTX-NEXT: mov.b64 %rd3, multiple_grid_const_escape_param_2;
-; PTX-NEXT: cvta.param.u64 %rd4, %rd3;
-; PTX-NEXT: cvta.param.u64 %rd5, %rd2;
-; PTX-NEXT: add.u64 %rd6, %SP, 0;
-; PTX-NEXT: add.u64 %rd7, %SPL, 0;
-; PTX-NEXT: st.local.b32 [%rd7], %r1;
+; PTX-NEXT: mov.b64 %rd2, multiple_grid_const_escape_param_2;
+; PTX-NEXT: cvta.param.u64 %rd3, %rd2;
+; PTX-NEXT: cvta.param.u64 %rd4, %rd1;
+; PTX-NEXT: add.u64 %rd5, %SP, 0;
+; PTX-NEXT: add.u64 %rd6, %SPL, 0;
+; PTX-NEXT: st.local.b32 [%rd6], %r1;
; PTX-NEXT: { // callseq 1, 0
; PTX-NEXT: .param .b64 param0;
; PTX-NEXT: .param .b64 param1;
; PTX-NEXT: .param .b64 param2;
; PTX-NEXT: .param .b32 retval0;
-; PTX-NEXT: st.param.b64 [param2], %rd4;
-; PTX-NEXT: st.param.b64 [param1], %rd6;
-; PTX-NEXT: st.param.b64 [param0], %rd5;
+; PTX-NEXT: st.param.b64 [param2], %rd3;
+; PTX-NEXT: st.param.b64 [param1], %rd5;
+; PTX-NEXT: st.param.b64 [param0], %rd4;
; PTX-NEXT: prototype_1 : .callprototype (.param .b32 _) _ (.param .b64 _, .param .b64 _, .param .b64 _);
-; PTX-NEXT: mov.b64 %rd1, escape3;
-; PTX-NEXT: call (retval0), %rd1, (param0, param1, param2), prototype_1;
+; PTX-NEXT: mov.b64 %rd7, escape3;
+; PTX-NEXT: call (retval0), %rd7, (param0, param1, param2), prototype_1;
; PTX-NEXT: } // callseq 1
; PTX-NEXT: ret;
; OPT-LABEL: define ptx_kernel void @multiple_grid_const_escape(
@@ -256,20 +256,20 @@ define ptx_kernel void @grid_const_partial_escape(ptr byval(i32) %input, ptr %ou
; PTX-NEXT: .reg .b64 %rd<6>;
; PTX-EMPTY:
; PTX-NEXT: // %bb.0:
-; PTX-NEXT: mov.b64 %rd2, grid_const_partial_escape_param_0;
-; PTX-NEXT: ld.param.b64 %rd3, [grid_const_partial_escape_param_1];
-; PTX-NEXT: cvta.to.global.u64 %rd4, %rd3;
-; PTX-NEXT: cvta.param.u64 %rd5, %rd2;
+; PTX-NEXT: mov.b64 %rd1, grid_const_partial_escape_param_0;
+; PTX-NEXT: ld.param.b64 %rd2, [grid_const_partial_escape_param_1];
+; PTX-NEXT: cvta.to.global.u64 %rd3, %rd2;
+; PTX-NEXT: cvta.param.u64 %rd4, %rd1;
; PTX-NEXT: ld.param.b32 %r1, [grid_const_partial_escape_param_0];
; PTX-NEXT: add.s32 %r2, %r1, %r1;
-; PTX-NEXT: st.global.b32 [%rd4], %r2;
+; PTX-NEXT: st.global.b32 [%rd3], %r2;
; PTX-NEXT: { // callseq 2, 0
; PTX-NEXT: .param .b64 param0;
; PTX-NEXT: .param .b32 retval0;
-; PTX-NEXT: st.param.b64 [param0], %rd5;
+; PTX-NEXT: st.param.b64 [param0], %rd4;
; PTX-NEXT: prototype_2 : .callprototype (.param .b32 _) _ (.param .b64 _);
-; PTX-NEXT: mov.b64 %rd1, escape;
-; PTX-NEXT: call (retval0), %rd1, (param0), prototype_2;
+; PTX-NEXT: mov.b64 %rd5, escape;
+; PTX-NEXT: call (retval0), %rd5, (param0), prototype_2;
; PTX-NEXT: } // callseq 2
; PTX-NEXT: ret;
; OPT-LABEL: define ptx_kernel void @grid_const_partial_escape(
@@ -295,21 +295,21 @@ define ptx_kernel i32 @grid_const_partial_escapemem(ptr byval(%struct.s) %input,
; PTX-NEXT: .reg .b64 %rd<6>;
; PTX-EMPTY:
; PTX-NEXT: // %bb.0:
-; PTX-NEXT: mov.b64 %rd2, grid_const_partial_escapemem_param_0;
-; PTX-NEXT: ld.param.b64 %rd3, [grid_const_partial_escapemem_param_1];
-; PTX-NEXT: cvta.to.global.u64 %rd4, %rd3;
-; PTX-NEXT: cvta.param.u64 %rd5, %rd2;
+; PTX-NEXT: mov.b64 %rd1, grid_const_partial_escapemem_param_0;
+; PTX-NEXT: ld.param.b64 %rd2, [grid_const_partial_escapemem_param_1];
+; PTX-NEXT: cvta.to.global.u64 %rd3, %rd2;
+; PTX-NEXT: cvta.param.u64 %rd4, %rd1;
; PTX-NEXT: ld.param.b32 %r1, [grid_const_partial_escapemem_param_0];
; PTX-NEXT: ld.param.b32 %r2, [grid_const_partial_escapemem_param_0+4];
-; PTX-NEXT: st.global.b64 [%rd4], %rd5;
+; PTX-NEXT: st.global.b64 [%rd3], %rd4;
; PTX-NEXT: add.s32 %r3, %r1, %r2;
; PTX-NEXT: { // callseq 3, 0
; PTX-NEXT: .param .b64 param0;
; PTX-NEXT: .param .b32 retval0;
-; PTX-NEXT: st.param.b64 [param0], %rd5;
+; PTX-NEXT: st.param.b64 [param0], %rd4;
; PTX-NEXT: prototype_3 : .callprototype (.param .b32 _) _ (.param .b64 _);
-; PTX-NEXT: mov.b64 %rd1, escape;
-; PTX-NEXT: call (retval0), %rd1, (param0), prototype_3;
+; PTX-NEXT: mov.b64 %rd5, escape;
+; PTX-NEXT: call (retval0), %rd5, (param0), prototype_3;
; PTX-NEXT: } // callseq 3
; PTX-NEXT: st.param.b32 [func_retval0], %r3;
; PTX-NEXT: ret;
diff --git a/llvm/test/CodeGen/NVPTX/param-vectorize-device.ll b/llvm/test/CodeGen/NVPTX/param-vectorize-device.ll
index a592b82614f43..51f6b00601069 100644
--- a/llvm/test/CodeGen/NVPTX/param-vectorize-device.ll
+++ b/llvm/test/CodeGen/NVPTX/param-vectorize-device.ll
@@ -150,8 +150,8 @@ define dso_local void @caller_St4x3(ptr nocapture noundef readonly byval(%struct
; CHECK: )
; CHECK: .param .align 16 .b8 param0[12];
; CHECK: .param .align 16 .b8 retval0[12];
- ; CHECK: st.param.v2.b32 [param0], {{{%r[0-9]+}}, {{%r[0-9]+}}};
- ; CHECK: st.param.b32 [param0+8], {{%r[0-9]+}};
+ ; CHECK-DAG: st.param.v2.b32 [param0], {{{%r[0-9]+}}, {{%r[0-9]+}}};
+ ; CHECK-DAG: st.param.b32 [param0+8], {{%r[0-9]+}};
; CHECK: call.uni (retval0), callee_St4x3, (param0);
; CHECK: ld.param.v2.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0];
; CHECK: ld.param.b32 {{%r[0-9]+}}, [retval0+8];
@@ -240,8 +240,8 @@ define dso_local void @caller_St4x5(ptr nocapture noundef readonly byval(%struct
; CHECK: )
; CHECK: .param .align 16 .b8 param0[20];
; CHECK: .param .align 16 .b8 retval0[20];
- ; CHECK: st.param.v4.b32 [param0], {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
- ; CHECK: st.param.b32 [param0+16], {{%r[0-9]+}};
+ ; CHECK-DAG: st.param.v4.b32 [param0], {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
+ ; CHECK-DAG: st.param.b32 [param0+16], {{%r[0-9]+}};
; CHECK: call.uni (retval0), callee_St4x5, (param0);
; CHECK: ld.param.v4.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0];
; CHECK: ld.param.b32 {{%r[0-9]+}}, [retval0+16];
@@ -296,8 +296,8 @@ define dso_local void @caller_St4x6(ptr nocapture noundef readonly byval(%struct
; CHECK: )
; CHECK: .param .align 16 .b8 param0[24];
; CHECK: .param .align 16 .b8 retval0[24];
- ; CHECK: st.param.v4.b32 [param0], {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
- ; CHECK: st.param.v2.b32 [param0+16], {{{%r[0-9]+}}, {{%r[0-9]+}}};
+ ; CHECK-DAG: st.param.v4.b32 [param0], {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
+ ; CHECK-DAG: st.param.v2.b32 [param0+16], {{{%r[0-9]+}}, {{%r[0-9]+}}};
; CHECK: call.uni (retval0), callee_St4x6, (param0);
; CHECK: ld.param.v4.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0];
; CHECK: ld.param.v2.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0+16];
@@ -358,9 +358,9 @@ define dso_local void @caller_St4x7(ptr nocapture noundef readonly byval(%struct
; CHECK: )
; CHECK: .param .align 16 .b8 param0[28];
; CHECK: .param .align 16 .b8 retval0[28];
- ; CHECK: st.param.v4.b32 [param0], {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
- ; CHECK: st.param.v2.b32 [param0+16], {{{%r[0-9]+}}, {{%r[0-9]+}}};
- ; CHECK: st.param.b32 [param0+24], {{%r[0-9]+}};
+ ; CHECK-DAG: st.param.v4.b32 [param0], {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
+ ; CHECK-DAG: st.param.v2.b32 [param0+16], {{{%r[0-9]+}}, {{%r[0-9]+}}};
+ ; CHECK-DAG: st.param.b32 [param0+24], {{%r[0-9]+}};
; CHECK: call.uni (retval0), callee_St4x7, (param0);
; CHECK: ld.param.v4.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0];
; CHECK: ld.param.v2.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0+16];
@@ -566,8 +566,8 @@ define dso_local void @caller_St8x3(ptr nocapture noundef readonly byval(%struct
; CHECK: )
; CHECK: .param .align 16 .b8 param0[24];
; CHECK: .param .align 16 .b8 retval0[24];
- ; CHECK: st.param.v2.b64 [param0], {{{%rd[0-9]+}}, {{%rd[0-9]+}}};
- ; CHECK: st.param.b64 [param0+16], {{%rd[0-9]+}};
+ ; CHECK-DAG: st.param.v2.b64 [param0], {{{%rd[0-9]+}}, {{%rd[0-9]+}}};
+ ; CHECK-DAG: st.param.b64 [param0+16], {{%rd[0-9]+}};
; CHECK: call.uni (retval0), callee_St8x3, (param0);
; CHECK: ld.param.v2.b64 {{{%rd[0-9]+}}, {{%rd[0-9]+}}}, [retval0];
; CHECK: ld.param.b64 {{%rd[0-9]+}}, [retval0+16];
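
In patch 2's byval path, each merged piece of the source aggregate is loaded directly with the widest legal type and stored straight into the param space; the per-piece alignment is narrowed from the byval alignment with commonAlignment. A rough standalone model of that narrowing for power-of-two base alignments (the in-tree helper is llvm::commonAlignment):

#include <cstdint>
#include <iostream>

// Alignment known to hold at BaseAlign + Offset: the largest power of two
// dividing both. This is why the piece at offset 8 of a 16-byte-aligned
// struct is loaded with 8-byte alignment, while offset 0 keeps all 16 bytes.
static uint64_t commonAlignmentSketch(uint64_t BaseAlign, uint64_t Offset) {
  if (Offset == 0)
    return BaseAlign;
  const uint64_t OffsetAlign = Offset & -Offset; // lowest set bit of Offset
  return BaseAlign < OffsetAlign ? BaseAlign : OffsetAlign;
}

int main() {
  std::cout << commonAlignmentSketch(16, 0) << "\n";  // 16
  std::cout << commonAlignmentSketch(16, 8) << "\n";  // 8
  std::cout << commonAlignmentSketch(16, 24) << "\n"; // 8
}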
From ca2239237cb847eef14e980d8af5ed4d327b8289 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Thu, 31 Jul 2025 05:05:21 +0000
Subject: [PATCH 3/3] address comments
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 22 ++++++++++++---------
1 file changed, 13 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 3e48e61f8c6c8..124914d9e4e9d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -382,12 +382,16 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
}
}
+// We return an EVT that can hold N VTs
+// If the VT is a vector, the resulting EVT is a flat vector with the same
+// element type as VT's element type.
static EVT getVectorizedVT(EVT VT, unsigned N, LLVMContext &C) {
if (N == 1)
return VT;
- const unsigned PackingAmt = VT.isVector() ? VT.getVectorNumElements() : 1;
- return EVT::getVectorVT(C, VT.getScalarType(), N * PackingAmt);
+ return VT.isVector() ? EVT::getVectorVT(C, VT.getScalarType(),
+ VT.getVectorNumElements() * N)
+ : EVT::getVectorVT(C, VT, N);
}
static SDValue getExtractVectorizedValue(SDValue V, unsigned I, EVT VT,
@@ -407,9 +411,8 @@ static SDValue getExtractVectorizedValue(SDValue V, unsigned I, EVT VT,
}
template <typename T>
-static inline SDValue getBuildVectorizedValue(T GetElement, unsigned N,
- const SDLoc &dl,
- SelectionDAG &DAG) {
+static inline SDValue getBuildVectorizedValue(unsigned N, const SDLoc &dl,
+ SelectionDAG &DAG, T GetElement) {
if (N == 1)
return GetElement(0);
@@ -1650,9 +1653,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
? MaybeAlign(std::nullopt)
: commonAlignment(ArgAlign, Offset);
- SDValue Val = getBuildVectorizedValue(
- [&](unsigned K) { return GetStoredValue(J + K); }, NumElts, dl,
- DAG);
+ SDValue Val =
+ getBuildVectorizedValue(NumElts, dl, DAG, [&](unsigned K) {
+ return GetStoredValue(J + K);
+ });
SDValue StoreParam =
DAG.getStore(ArgDeclare, dl, Val, Ptr,
@@ -3416,7 +3420,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
: commonAlignment(RetAlign, Offsets[I]);
SDValue Val = getBuildVectorizedValue(
- [&](unsigned K) { return GetRetVal(I + K); }, NumElts, dl, DAG);
+ NumElts, dl, DAG, [&](unsigned K) { return GetRetVal(I + K); });
SDValue Ptr =
DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
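
Patch 3 is mostly cosmetic: the element count moves ahead of the callable in getBuildVectorizedValue so the GetElement lambda sits last, letting call sites format the lambda as a trailing block. A generic illustration of why that ordering reads better (not the LLVM signature):

#include <iostream>
#include <vector>

// Generic helper with the callable as the trailing parameter, analogous to
// the reordered getBuildVectorizedValue(N, dl, DAG, GetElement).
template <typename GetElementT>
static std::vector<int> buildVectorized(unsigned N, GetElementT GetElement) {
  std::vector<int> Values;
  for (unsigned I = 0; I < N; ++I)
    Values.push_back(GetElement(I));
  return Values;
}

int main() {
  // The lambda trails the fixed arguments instead of being wedged between them.
  const std::vector<int> Vals =
      buildVectorized(4, [](unsigned I) { return static_cast<int>(I) * 10; });
  for (int V : Vals)
    std::cout << V << " ";
  std::cout << "\n";
}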