[llvm-branch-commits] [llvm] e28c71d - Revert "[NVPTX] Rip out vestigial variadic support (NFC) (#202385)"
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Jun 16 02:59:47 PDT 2026
Author: Dmitry Vasilyev
Date: 2026-06-16T13:59:44+04:00
New Revision: e28c71d953dcdb0eacaf174a214f3cdc2ca09263
URL: https://github.com/llvm/llvm-project/commit/e28c71d953dcdb0eacaf174a214f3cdc2ca09263
DIFF: https://github.com/llvm/llvm-project/commit/e28c71d953dcdb0eacaf174a214f3cdc2ca09263.diff
LOG: Revert "[NVPTX] Rip out vestigial variadic support (NFC) (#202385)"
This reverts commit e63cd40ccce67f9472af9676185d7c87157043b4.
Added:
Modified:
llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/lib/Target/NVPTX/NVPTXISelLowering.h
llvm/lib/Target/NVPTX/NVPTXSubtarget.h
Removed:
################################################################################
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index cd5c27cebc182..b2efcb0f0d2b6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1360,21 +1360,25 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
const NVPTXMachineFunctionInfo *MFI =
MF ? MF->getInfo<NVPTXMachineFunctionInfo>() : nullptr;
+ bool IsFirst = true;
const bool IsKernelFunc = isKernelFunction(*F);
- assert(!F->isVarArg() && "VarArg functions lowered in ExpandVariadics");
-
- if (F->arg_empty()) {
+ if (F->arg_empty() && !F->isVarArg()) {
O << "()";
return;
}
O << "(\n";
- auto EmitParam = [&](const Argument &Arg) {
+ for (const Argument &Arg : F->args()) {
Type *Ty = Arg.getType();
const std::string ParamSym = TLI->getParamName(F, Arg.getArgNo());
+ if (!IsFirst)
+ O << ",\n";
+
+ IsFirst = false;
+
// Handle image/sampler parameters
if (IsKernelFunc) {
const PTXOpaqueType ArgOpaqueType = getPTXOpaqueType(Arg);
@@ -1398,7 +1402,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
llvm_unreachable("handled above");
}
O << ParamSym;
- return;
+ continue;
}
}
@@ -1420,7 +1424,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
<< "[" << DL.getTypeAllocSize(ETy) << "]";
- return;
+ continue;
}
if (shouldPassAsArray(Ty)) {
@@ -1434,7 +1438,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
<< "[" << DL.getTypeAllocSize(Ty) << "]";
- return;
+ continue;
}
// Just a scalar
auto *PTy = dyn_cast<PointerType>(Ty);
@@ -1468,7 +1472,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
O << " .align " << Arg.getParamAlign().valueOrOne().value() << " "
<< ParamSym;
- return;
+ continue;
}
// non-pointer scalar to kernel func
@@ -1479,7 +1483,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
else
O << getPTXFundamentalTypeStr(Ty);
O << " " << ParamSym;
- return;
+ continue;
}
// Non-kernel function, just print .param .b<size> for ABI
// and .reg .b<size> for non-ABI
@@ -1492,8 +1496,14 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
} else
Size = Ty->getPrimitiveSizeInBits();
O << "\t.param .b" << Size << " " << ParamSym;
- };
- interleave(F->args(), O, EmitParam, ",\n");
+ }
+
+ if (F->isVarArg()) {
+ if (!IsFirst)
+ O << ",\n";
+ O << "\t.param .align " << STI.getMaxRequiredAlignment() << " .b8 "
+ << TLI->getParamName(F, /* vararg */ -1) << "[]";
+ }
O << "\n)";
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 7555845847935..17d9f857312d6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -308,10 +308,12 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
LLVMContext &Ctx, CallingConv::ID CallConv,
Type *Ty, SmallVectorImpl<EVT> &ValueVTs,
- SmallVectorImpl<TypeSize> &Offsets) {
+ SmallVectorImpl<uint64_t> &Offsets,
+ uint64_t StartingOffset = 0) {
SmallVector<EVT, 16> TempVTs;
- SmallVector<TypeSize, 16> TempOffsets;
- ComputeValueVTs(TLI, DL, Ty, TempVTs, /*MemVTs=*/nullptr, &TempOffsets);
+ SmallVector<uint64_t, 16> TempOffsets;
+ ComputeValueVTs(TLI, DL, Ty, TempVTs, /*MemVTs=*/nullptr, &TempOffsets,
+ StartingOffset);
for (const auto [VT, Off] : zip(TempVTs, TempOffsets)) {
MVT RegisterVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, VT);
@@ -426,9 +428,10 @@ static EVT promoteScalarIntegerPTX(const EVT VT) {
// parameter starting at index Idx using a single vectorized op of
// size AccessSize. If so, it returns the number of param pieces
// covered by the vector op. Otherwise, it returns 1.
+template <typename T>
static unsigned canMergeParamLoadStoresStartingAt(
unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs,
- const SmallVectorImpl<TypeSize> &Offsets, Align ParamAlignment) {
+ const SmallVectorImpl<T> &Offsets, Align ParamAlignment) {
// Can't vectorize if param alignment is not sufficient.
if (ParamAlignment < AccessSize)
@@ -478,10 +481,17 @@ static unsigned canMergeParamLoadStoresStartingAt(
// of the same size as ValueVTs indicating how each piece should be
// loaded/stored (i.e. as a scalar, or as part of a vector
// load/store).
+template <typename T>
static SmallVector<unsigned, 16>
VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
- const SmallVectorImpl<TypeSize> &Offsets,
- Align ParamAlignment) {
+ const SmallVectorImpl<T> &Offsets, Align ParamAlignment,
+ bool IsVAArg = false) {
+ // Set vector size to match ValueVTs and mark all elements as
+ // scalars by default.
+
+ if (IsVAArg)
+ return SmallVector<unsigned>(ValueVTs.size(), 1);
+
SmallVector<unsigned, 16> VectorInfo;
const auto GetNumElts = [&](unsigned I) -> unsigned {
@@ -796,6 +806,12 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// DEBUGTRAP can be lowered to PTX brkpt
setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal);
+ // Support varargs.
+ setOperationAction(ISD::VASTART, MVT::Other, Custom);
+ setOperationAction(ISD::VAARG, MVT::Other, Custom);
+ setOperationAction(ISD::VACOPY, MVT::Other, Expand);
+ setOperationAction(ISD::VAEND, MVT::Other, Expand);
+
setOperationAction({ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX},
{MVT::i16, MVT::i32, MVT::i64}, Legal);
// PTX abs.s is undefined for INT_MIN, so ISD::ABS (which requires
@@ -1191,7 +1207,8 @@ SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
std::string NVPTXTargetLowering::getPrototype(
const DataLayout &DL, Type *RetTy, const ArgListTy &Args,
- const SmallVectorImpl<ISD::OutputArg> &Outs, const CallBase &CB,
+ const SmallVectorImpl<ISD::OutputArg> &Outs,
+ std::optional<unsigned> FirstVAArg, const CallBase &CB,
unsigned UniqueCallSite) const {
auto PtrVT = getPointerTy(DL);
@@ -1232,13 +1249,20 @@ std::string NVPTXTargetLowering::getPrototype(
}
O << "_ (";
+ bool first = true;
+
+ const unsigned NumArgs = FirstVAArg.value_or(Args.size());
auto AllOuts = ArrayRef(Outs);
- auto MakeArg = [&](const unsigned I) {
+ for (const unsigned I : llvm::seq(NumArgs)) {
const auto ArgOuts =
AllOuts.take_while([I](auto O) { return O.OrigArgIndex == I; });
AllOuts = AllOuts.drop_front(ArgOuts.size());
Type *Ty = Args[I].Ty;
+ if (!first) {
+ O << ", ";
+ }
+ first = false;
if (ArgOuts[0].Flags.isByVal()) {
// Indirect calls need strict ABI alignment so we disable optimizations by
@@ -1250,33 +1274,34 @@ std::string NVPTXTargetLowering::getPrototype(
O << ".param .align " << ParamByValAlign.value() << " .b8 _["
<< ArgOuts[0].Flags.getByValSize() << "]";
- return;
- }
-
- if (shouldPassAsArray(Ty)) {
- Align ParamAlign =
- getPTXParamAlign(&CB, Ty, I + AttributeList::FirstArgIndex, DL);
- O << ".param .align " << ParamAlign.value() << " .b8 _["
- << DL.getTypeAllocSize(Ty) << "]";
- return;
- }
- // i8 types in IR will be i16 types in SDAG
- assert((getValueType(DL, Ty) == ArgOuts[0].VT ||
- (getValueType(DL, Ty) == MVT::i8 && ArgOuts[0].VT == MVT::i16)) &&
- "type mismatch between callee prototype and arguments");
- // scalar type
- unsigned sz = 0;
- if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
- sz = promoteScalarArgumentSize(ITy->getBitWidth());
- } else if (isa<PointerType>(Ty)) {
- sz = PtrVT.getSizeInBits();
} else {
- sz = Ty->getPrimitiveSizeInBits();
+ if (shouldPassAsArray(Ty)) {
+ Align ParamAlign =
+ getPTXParamAlign(&CB, Ty, I + AttributeList::FirstArgIndex, DL);
+ O << ".param .align " << ParamAlign.value() << " .b8 _["
+ << DL.getTypeAllocSize(Ty) << "]";
+ continue;
+ }
+ // i8 types in IR will be i16 types in SDAG
+ assert((getValueType(DL, Ty) == ArgOuts[0].VT ||
+ (getValueType(DL, Ty) == MVT::i8 && ArgOuts[0].VT == MVT::i16)) &&
+ "type mismatch between callee prototype and arguments");
+ // scalar type
+ unsigned sz = 0;
+ if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
+ sz = promoteScalarArgumentSize(ITy->getBitWidth());
+ } else if (isa<PointerType>(Ty)) {
+ sz = PtrVT.getSizeInBits();
+ } else {
+ sz = Ty->getPrimitiveSizeInBits();
+ }
+ O << ".param .b" << sz << " _";
}
- O << ".param .b" << sz << " _";
- };
- interleave(seq(Args.size()), O, MakeArg, ", ");
+ }
+ if (FirstVAArg)
+ O << (first ? "" : ",") << " .param .align "
+ << STI.getMaxRequiredAlignment() << " .b8 _[]";
O << ")";
if (shouldEmitPTXNoReturn(&CB, *nvTM))
O << " .noreturn";
@@ -1335,7 +1360,10 @@ static SDValue correctParamType(SDValue V, EVT ExpectedVT,
SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SmallVectorImpl<SDValue> &InVals) const {
- assert(!CLI.IsVarArg && "Vararg functions lowered in ExpandVariadics");
+ if (CLI.IsVarArg && (STI.getPTXVersion() < 60 || STI.getSmVersion() < 30))
+ report_fatal_error(
+ "Support for variadic functions (unsized array parameter) introduced "
+ "in PTX ISA version 6.0 and requires target sm_30.");
SelectionDAG &DAG = CLI.DAG;
SDLoc dl = CLI.DL;
@@ -1381,11 +1409,32 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
return Declare;
};
- // For each argument, we declare a param scalar or a param byte array in the
- // .param space, and store the argument value to that param scalar or array
- // starting at offset 0.
- assert(CLI.Args.size() == CLI.NumFixedArgs &&
- "function with extra arguments");
+ // Variadic arguments.
+ //
+ // Normally, for each argument, we declare a param scalar or a param
+ // byte array in the .param space, and store the argument value to that
+ // param scalar or array starting at offset 0.
+ //
+ // In the case of the first variadic argument, we declare a vararg byte array
+ // with size 0. The exact size of this array isn't known at this point, so
+ // it'll be patched later. All the variadic arguments will be stored to this
+ // array at a certain offset (which gets tracked by 'VAOffset'). The offset is
+ // initially set to 0, so it can be used for non-variadic arguments (which use
+ // 0 offset) to simplify the code.
+ //
+ // After all vararg is processed, 'VAOffset' holds the size of the
+ // vararg byte array.
+ assert((CLI.IsVarArg || CLI.Args.size() == CLI.NumFixedArgs) &&
+ "Non-VarArg function with extra arguments");
+
+ const unsigned FirstVAArg = CLI.NumFixedArgs; // position of first variadic
+ unsigned VAOffset = 0; // current offset in the param array
+
+ const SDValue VADeclareParam =
+ CLI.Args.size() > FirstVAArg
+ ? MakeDeclareArrayParam(getCallParamSymbol(DAG, FirstVAArg, MVT::i32),
+ Align(STI.getMaxRequiredAlignment()), 0)
+ : SDValue();
// Args.size() and Outs.size() need not match.
// Outs.size() will be larger
@@ -1411,9 +1460,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
AllOuts = AllOuts.drop_front(ArgOuts.size());
AllOutVals = AllOutVals.drop_front(ArgOuts.size());
+ const bool IsVAArg = (ArgI >= FirstVAArg);
const bool IsByVal = Arg.IsByVal;
- const SDValue ParamSymbol = getCallParamSymbol(DAG, ArgI, MVT::i32);
+ const SDValue ParamSymbol =
+ getCallParamSymbol(DAG, IsVAArg ? FirstVAArg : ArgI, MVT::i32);
assert((!IsByVal || Arg.IndirectType) &&
"byval arg must have indirect type");
@@ -1437,6 +1488,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
"type size mismatch");
const SDValue ArgDeclare = [&]() {
+ if (IsVAArg)
+ return VADeclareParam;
+
if (IsByVal || shouldPassAsArray(Arg.Ty))
return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TySize);
@@ -1453,12 +1507,15 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
const auto PointerInfo = refinePtrAS(SrcPtr, DAG, DL, *this);
const Align BaseSrcAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
+ if (IsVAArg)
+ VAOffset = alignTo(VAOffset, ArgAlign);
+
SmallVector<EVT, 4> ValueVTs, MemVTs;
SmallVector<TypeSize, 4> Offsets;
ComputeValueVTs(*this, DL, ETy, ValueVTs, &MemVTs, &Offsets);
unsigned J = 0;
- const auto VI = VectorizePTXValueVTs(MemVTs, Offsets, ArgAlign);
+ const auto VI = VectorizePTXValueVTs(MemVTs, Offsets, ArgAlign, IsVAArg);
for (const unsigned NumElts : VI) {
EVT LoadVT = getVectorizedVT(MemVTs[J], NumElts, Ctx);
Align SrcAlign = commonAlignment(BaseSrcAlign, Offsets[J]);
@@ -1466,8 +1523,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SDValue SrcLoad =
DAG.getLoad(LoadVT, dl, CallChain, SrcAddr, PointerInfo, SrcAlign);
- Align ParamAlign = commonAlignment(ArgAlign, Offsets[J]);
- SDValue ParamAddr = DAG.getObjectPtrOffset(dl, ParamSymbol, Offsets[J]);
+ TypeSize ParamOffset = Offsets[J].getWithIncrement(VAOffset);
+ Align ParamAlign = commonAlignment(ArgAlign, ParamOffset);
+ SDValue ParamAddr =
+ DAG.getObjectPtrOffset(dl, ParamSymbol, ParamOffset);
SDValue StoreParam = DAG.getStore(
ArgDeclare, dl, SrcLoad, ParamAddr,
MachinePointerInfo(NVPTX::AddressSpace::DeviceParam), ParamAlign);
@@ -1475,10 +1534,13 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
J += NumElts;
}
+ if (IsVAArg)
+ VAOffset += TySize;
} else {
SmallVector<EVT, 16> VTs;
- SmallVector<TypeSize, 16> Offsets;
- ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, Arg.Ty, VTs, Offsets);
+ SmallVector<uint64_t, 16> Offsets;
+ ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, Arg.Ty, VTs, Offsets,
+ VAOffset);
assert(VTs.size() == Offsets.size() && "Size mismatch");
assert(VTs.size() == ArgOuts.size() && "Size mismatch");
@@ -1503,11 +1565,30 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
};
unsigned J = 0;
- const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
+ const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
for (const unsigned NumElts : VI) {
- TypeSize Offset = Offsets[J];
+ const EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
+
+ unsigned Offset;
+ if (IsVAArg) {
+ // TODO: We may need to support vector types that can be passed
+ // as scalars in variadic arguments.
+ assert(NumElts == 1 &&
+ "Vectorization should be disabled for vaargs.");
+
+ // Align each part of the variadic argument to their type.
+ VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
+ Offset = VAOffset;
+
+ const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
+ VAOffset += DL.getTypeAllocSize(TheStoreType.getTypeForEVT(Ctx));
+ } else {
+ assert(VAOffset == 0 && "VAOffset must be 0 for non-VA args");
+ Offset = Offsets[J];
+ }
- SDValue Ptr = DAG.getObjectPtrOffset(dl, ParamSymbol, Offset);
+ SDValue Ptr =
+ DAG.getObjectPtrOffset(dl, ParamSymbol, TypeSize::getFixed(Offset));
const MaybeAlign CurrentAlign = ExtendIntegerParam
? MaybeAlign(std::nullopt)
@@ -1541,6 +1622,17 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
}
}
+ // Set the size of the vararg param byte array if the callee is a variadic
+ // function and the variadic part is not empty.
+ if (VADeclareParam) {
+ SDValue DeclareParamOps[] = {VADeclareParam.getOperand(0),
+ VADeclareParam.getOperand(1),
+ VADeclareParam.getOperand(2), GetI32(VAOffset),
+ VADeclareParam.getOperand(4)};
+ DAG.MorphNodeTo(VADeclareParam.getNode(), VADeclareParam.getOpcode(),
+ VADeclareParam->getVTList(), DeclareParamOps);
+ }
+
const auto *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
const auto *CalleeF = Func ? dyn_cast<Function>(Func->getGlobal()) : nullptr;
@@ -1574,8 +1666,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// instruction.
// The prototype is embedded in a string and put as the operand for a
// CallPrototype SDNode which will print out to the value of the string.
+ const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
std::string Proto =
- getPrototype(DL, RetTy, Args, CLI.Outs, *CB, UniqueCallSite);
+ getPrototype(DL, RetTy, Args, CLI.Outs,
+ HasVAArgs ? std::optional(FirstVAArg) : std::nullopt, *CB,
+ UniqueCallSite);
const char *ProtoStr = nvTM->getStrPool().save(Proto).data();
const SDValue PrototypeDeclare = DAG.getNode(
NVPTXISD::CallPrototype, dl, MVT::Other,
@@ -1609,7 +1704,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SmallVector<SDValue, 16> ProxyRegOps;
if (!Ins.empty()) {
SmallVector<EVT, 16> VTs;
- SmallVector<TypeSize, 16> Offsets;
+ SmallVector<uint64_t, 16> Offsets;
ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, RetTy, VTs, Offsets);
assert(VTs.size() == Ins.size() && "Bad value decomposition");
@@ -1634,7 +1729,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
const EVT LoadVT =
ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
const EVT VecVT = getVectorizedVT(LoadVT, NumElts, Ctx);
- SDValue Ptr = DAG.getObjectPtrOffset(dl, RetSymbol, Offsets[I]);
+ SDValue Ptr =
+ DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
SDValue R = DAG.getLoad(
VecVT, dl, Call, Ptr,
@@ -3385,6 +3481,10 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerFP_ROUND(Op, DAG);
case ISD::FP_EXTEND:
return LowerFP_EXTEND(Op, DAG);
+ case ISD::VAARG:
+ return LowerVAARG(Op, DAG);
+ case ISD::VASTART:
+ return LowerVASTART(Op, DAG);
case ISD::FSHL:
case ISD::FSHR:
return lowerFSH(Op, DAG);
@@ -3464,6 +3564,63 @@ SDValue NVPTXTargetLowering::LowerADDRSPACECAST(SDValue Op,
return Op;
}
+// This function is almost a copy of SelectionDAG::expandVAArg().
+// The only diff is that this one produces loads from local address space.
+SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
+ const TargetLowering *TLI = STI.getTargetLowering();
+ SDLoc DL(Op);
+
+ SDNode *Node = Op.getNode();
+ const Value *V = cast<SrcValueSDNode>(Node->getOperand(2))->getValue();
+ EVT VT = Node->getValueType(0);
+ auto *Ty = VT.getTypeForEVT(*DAG.getContext());
+ SDValue Tmp1 = Node->getOperand(0);
+ SDValue Tmp2 = Node->getOperand(1);
+ const MaybeAlign MA(Node->getConstantOperandVal(3));
+
+ SDValue VAListLoad = DAG.getLoad(TLI->getPointerTy(DAG.getDataLayout()), DL,
+ Tmp1, Tmp2, MachinePointerInfo(V));
+ SDValue VAList = VAListLoad;
+
+ if (MA && *MA > TLI->getMinStackArgumentAlignment()) {
+ VAList = DAG.getNode(
+ ISD::ADD, DL, VAList.getValueType(), VAList,
+ DAG.getConstant(MA->value() - 1, DL, VAList.getValueType()));
+
+ VAList = DAG.getNode(ISD::AND, DL, VAList.getValueType(), VAList,
+ DAG.getSignedConstant(-(int64_t)MA->value(), DL,
+ VAList.getValueType()));
+ }
+
+ // Increment the pointer, VAList, to the next vaarg
+ Tmp1 = DAG.getNode(ISD::ADD, DL, VAList.getValueType(), VAList,
+ DAG.getConstant(DAG.getDataLayout().getTypeAllocSize(Ty),
+ DL, VAList.getValueType()));
+
+ // Store the incremented VAList to the legalized pointer
+ Tmp1 = DAG.getStore(VAListLoad.getValue(1), DL, Tmp1, Tmp2,
+ MachinePointerInfo(V));
+
+ const Value *SrcV = Constant::getNullValue(
+ PointerType::get(*DAG.getContext(), ADDRESS_SPACE_LOCAL));
+
+ // Load the actual argument out of the pointer VAList
+ return DAG.getLoad(VT, DL, Tmp1, VAList, MachinePointerInfo(SrcV));
+}
+
+SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
+ const TargetLowering *TLI = STI.getTargetLowering();
+ SDLoc DL(Op);
+ EVT PtrVT = TLI->getPointerTy(DAG.getDataLayout());
+
+ // Store the address of unsized array <function>_vararg[] in the ap object.
+ SDValue VAReg = getParamSymbol(DAG, /* vararg */ -1, PtrVT);
+
+ const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
+ return DAG.getStore(Op.getOperand(0), DL, VAReg, Op.getOperand(1),
+ MachinePointerInfo(SV));
+}
+
static std::pair<MemSDNode *, uint32_t>
convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG,
const NVPTXSubtarget &STI) {
@@ -3884,14 +4041,16 @@ bool NVPTXTargetLowering::splitValueIntoRegisterParts(
// This creates target external symbol for a function parameter.
// Name of the symbol is composed from its index and the function name.
-SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, unsigned I,
+// Negative index corresponds to special parameter (unsized array) used for
+// passing variable arguments.
+SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int I,
EVT T) const {
StringRef SavedStr = nvTM->getStrPool().save(
getParamName(&DAG.getMachineFunction().getFunction(), I));
return DAG.getExternalSymbol(SavedStr.data(), T);
}
-SDValue NVPTXTargetLowering::getCallParamSymbol(SelectionDAG &DAG, unsigned I,
+SDValue NVPTXTargetLowering::getCallParamSymbol(SelectionDAG &DAG, int I,
EVT T) const {
const StringRef SavedStr = nvTM->getStrPool().save("param" + Twine(I));
return DAG.getExternalSymbol(SavedStr.data(), T);
@@ -3901,8 +4060,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
- assert(!isVarArg && "Vararg functions lowered in ExpandVariadics");
-
const DataLayout &DL = DAG.getDataLayout();
LLVMContext &Ctx = *DAG.getContext();
auto PtrVT = getPointerTy(DAG.getDataLayout());
@@ -3978,7 +4135,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
InVals.push_back(P);
} else {
SmallVector<EVT, 16> VTs;
- SmallVector<TypeSize, 16> Offsets;
+ SmallVector<uint64_t, 16> Offsets;
ComputePTXValueVTs(*this, DL, Ctx, CallConv, Ty, VTs, Offsets);
assert(VTs.size() == ArgIns.size() && "Size mismatch");
assert(VTs.size() == Offsets.size() && "Size mismatch");
@@ -3993,7 +4150,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
const EVT VecVT = getVectorizedVT(LoadVT, NumElts, Ctx);
- SDValue VecAddr = DAG.getObjectPtrOffset(dl, ArgSymbol, Offsets[I]);
+ SDValue VecAddr = DAG.getObjectPtrOffset(
+ dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
const Align PartAlign = commonAlignment(ArgAlign, Offsets[I]);
const unsigned AS = IsKernel ? NVPTX::AddressSpace::EntryParam
@@ -4027,8 +4185,6 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
const SmallVectorImpl<ISD::OutputArg> &Outs,
const SmallVectorImpl<SDValue> &OutVals,
const SDLoc &dl, SelectionDAG &DAG) const {
- assert(!isVarArg && "Vararg functions lowered in ExpandVariadics");
-
const Function &F = DAG.getMachineFunction().getFunction();
Type *RetTy = F.getReturnType();
@@ -4051,7 +4207,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
SmallVector<EVT, 16> VTs;
- SmallVector<TypeSize, 16> Offsets;
+ SmallVector<uint64_t, 16> Offsets;
ComputePTXValueVTs(*this, DL, Ctx, CallConv, RetTy, VTs, Offsets);
assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
@@ -4077,7 +4233,8 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
SDValue Val = getBuildVectorizedValue(
NumElts, dl, DAG, [&](unsigned K) { return GetRetVal(I + K); });
- SDValue Ptr = DAG.getObjectPtrOffset(dl, RetSymbol, Offsets[I]);
+ SDValue Ptr =
+ DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
Chain = DAG.getStore(Chain, dl, Val, Ptr,
MachinePointerInfo(NVPTX::AddressSpace::DeviceParam),
@@ -5384,11 +5541,20 @@ void NVPTXTargetLowering::getTgtMemIntrinsic(
}
// Helper for getting a function parameter name. Name is composed from
-// its index and the function name.
+// its index and the function name. Negative index corresponds to special
+// parameter (unsized array) used for passing variable arguments.
std::string NVPTXTargetLowering::getParamName(const Function *F,
- unsigned Idx) const {
- return (getTargetMachine().getSymbol(F)->getName() + "_param_" + Twine(Idx))
- .str();
+ int Idx) const {
+ std::string ParamName;
+ raw_string_ostream ParamStr(ParamName);
+
+ ParamStr << getTargetMachine().getSymbol(F)->getName();
+ if (Idx < 0)
+ ParamStr << "_vararg";
+ else
+ ParamStr << "_param_" << Idx;
+
+ return ParamName;
}
/// isLegalAddressingMode - Return true if the addressing mode represented
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 76892c229a842..0e8dd6056af81 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -37,8 +37,9 @@ class NVPTXTargetLowering : public TargetLowering {
unsigned Intrinsic) const override;
// Helper for getting a function parameter name. Name is composed from
- // its index and the function name.
- std::string getParamName(const Function *F, unsigned Idx) const;
+ // its index and the function name. Negative index corresponds to special
+ // parameter (unsized array) used for passing variable arguments.
+ std::string getParamName(const Function *F, int Idx) const;
/// isLegalAddressingMode - Return true if the addressing mode represented
/// by AM is legal for this target, for a load/store of the specified type
@@ -84,6 +85,7 @@ class NVPTXTargetLowering : public TargetLowering {
std::string getPrototype(const DataLayout &DL, Type *, const ArgListTy &,
const SmallVectorImpl<ISD::OutputArg> &,
+ std::optional<unsigned> FirstVAArg,
const CallBase &CB, unsigned UniqueCallSite) const;
SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
@@ -192,8 +194,8 @@ class NVPTXTargetLowering : public TargetLowering {
const NVPTXSubtarget &STI; // cache the subtarget here
mutable unsigned GlobalUniqueCallSite;
- SDValue getParamSymbol(SelectionDAG &DAG, unsigned I, EVT T) const;
- SDValue getCallParamSymbol(SelectionDAG &DAG, unsigned I, EVT T) const;
+ SDValue getParamSymbol(SelectionDAG &DAG, int I, EVT T) const;
+ SDValue getCallParamSymbol(SelectionDAG &DAG, int I, EVT T) const;
SDValue LowerADDRSPACECAST(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) const;
@@ -226,6 +228,9 @@ class NVPTXTargetLowering : public TargetLowering {
SDValue LowerShiftRightParts(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerShiftLeftParts(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const;
+
SDValue LowerCopyToReg_128(SDValue Op, SelectionDAG &DAG) const;
unsigned getNumRegisters(LLVMContext &Context, EVT VT,
std::optional<MVT> RegisterVT) const override;
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
index e632204a444d5..1df5d326f63a6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -334,6 +334,14 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
bool hasNativeBF16Support(int Opcode) const;
+ // Get maximum value of required alignments among the supported data types.
+ // From the PTX ISA doc, section 8.2.3:
+ // The memory consistency model relates operations executed on memory
+ // locations with scalar data-types, which have a maximum size and alignment
+ // of 64 bits. Memory operations with a vector data-type are modelled as a
+ // set of equivalent memory operations with a scalar data-type, executed in
+ // an unspecified order on the elements in the vector.
+ unsigned getMaxRequiredAlignment() const { return 8; }
// Get the smallest cmpxchg word size that the hardware supports.
unsigned getMinCmpXchgSizeInBits() const { return 32; }
More information about the llvm-branch-commits
mailing list