[llvm] [NVPTX] Further cleanup call isel (PR #146411)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 30 12:35:09 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-nvptx
Author: Alex MacLean (AlexMaclean)
<details>
<summary>Changes</summary>
---
Patch is 315.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146411.diff
15 Files Affected:
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+150-153)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+10-6)
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+23-35)
- (modified) llvm/test/CodeGen/NVPTX/cmpxchg-sm60.ll (+540-540)
- (modified) llvm/test/CodeGen/NVPTX/cmpxchg-sm70.ll (+540-540)
- (modified) llvm/test/CodeGen/NVPTX/cmpxchg-sm90.ll (+540-540)
- (modified) llvm/test/CodeGen/NVPTX/cmpxchg.ll (+120-120)
- (modified) llvm/test/CodeGen/NVPTX/convert-int-sm20.ll (+3-3)
- (modified) llvm/test/CodeGen/NVPTX/extractelement.ll (+9-12)
- (modified) llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll (+4-4)
- (modified) llvm/test/CodeGen/NVPTX/lower-args.ll (+2-2)
- (modified) llvm/test/CodeGen/NVPTX/misched_func_call.ll (+6-6)
- (modified) llvm/test/CodeGen/NVPTX/st-param-imm.ll (+30-30)
- (modified) llvm/test/CodeGen/NVPTX/unaligned-param-load-store.ll (+96-96)
- (modified) llvm/test/CodeGen/NVPTX/variadics-backend.ll (+1-1)
``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d9192fbfceff1..a41b094faa8d6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -28,6 +28,7 @@
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/MachineMemOperand.h"
+#include "llvm/CodeGen/Register.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/TargetCallingConv.h"
@@ -390,35 +391,27 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
/// and promote them to a larger size if they're not.
///
/// Returns the promoted type, or the original type unchanged if no promotion
/// was needed.
-static std::optional<MVT> PromoteScalarIntegerPTX(const EVT &VT) {
+static EVT promoteScalarIntegerPTX(const EVT VT) {
if (VT.isScalarInteger()) {
- MVT PromotedVT;
switch (PowerOf2Ceil(VT.getFixedSizeInBits())) {
default:
llvm_unreachable(
"Promotion is not suitable for scalars of size larger than 64-bits");
case 1:
- PromotedVT = MVT::i1;
- break;
+ return MVT::i1;
case 2:
case 4:
case 8:
- PromotedVT = MVT::i8;
- break;
+ return MVT::i8;
case 16:
- PromotedVT = MVT::i16;
- break;
+ return MVT::i16;
case 32:
- PromotedVT = MVT::i32;
- break;
+ return MVT::i32;
case 64:
- PromotedVT = MVT::i64;
- break;
+ return MVT::i64;
}
- if (VT != PromotedVT)
- return PromotedVT;
}
- return std::nullopt;
+ return VT;
}
// Check whether we can merge loads/stores of some of the pieces of a
@@ -1053,10 +1046,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
break;
MAKE_CASE(NVPTXISD::RET_GLUE)
- MAKE_CASE(NVPTXISD::DeclareParam)
+ MAKE_CASE(NVPTXISD::DeclareArrayParam)
MAKE_CASE(NVPTXISD::DeclareScalarParam)
- MAKE_CASE(NVPTXISD::DeclareRet)
- MAKE_CASE(NVPTXISD::DeclareRetParam)
MAKE_CASE(NVPTXISD::CALL)
MAKE_CASE(NVPTXISD::LoadParam)
MAKE_CASE(NVPTXISD::LoadParamV2)
@@ -1162,8 +1153,8 @@ SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
}
std::string NVPTXTargetLowering::getPrototype(
- const DataLayout &DL, Type *retTy, const ArgListTy &Args,
- const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign RetAlign,
+ const DataLayout &DL, Type *RetTy, const ArgListTy &Args,
+ const SmallVectorImpl<ISD::OutputArg> &Outs,
std::optional<unsigned> FirstVAArg, const CallBase &CB,
unsigned UniqueCallSite) const {
auto PtrVT = getPointerTy(DL);
@@ -1172,22 +1163,22 @@ std::string NVPTXTargetLowering::getPrototype(
raw_string_ostream O(Prototype);
O << "prototype_" << UniqueCallSite << " : .callprototype ";
- if (retTy->isVoidTy()) {
+ if (RetTy->isVoidTy()) {
O << "()";
} else {
O << "(";
- if (shouldPassAsArray(retTy)) {
- assert(RetAlign && "RetAlign must be set for non-void return types");
- O << ".param .align " << RetAlign->value() << " .b8 _["
- << DL.getTypeAllocSize(retTy) << "]";
- } else if (retTy->isFloatingPointTy() || retTy->isIntegerTy()) {
+ if (shouldPassAsArray(RetTy)) {
+ const Align RetAlign = getArgumentAlignment(&CB, RetTy, 0, DL);
+ O << ".param .align " << RetAlign.value() << " .b8 _["
+ << DL.getTypeAllocSize(RetTy) << "]";
+ } else if (RetTy->isFloatingPointTy() || RetTy->isIntegerTy()) {
unsigned size = 0;
- if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
+ if (auto *ITy = dyn_cast<IntegerType>(RetTy)) {
size = ITy->getBitWidth();
} else {
- assert(retTy->isFloatingPointTy() &&
+ assert(RetTy->isFloatingPointTy() &&
"Floating point type expected here");
- size = retTy->getPrimitiveSizeInBits();
+ size = RetTy->getPrimitiveSizeInBits();
}
// PTX ABI requires all scalar return values to be at least 32
// bits in size. fp16 normally uses .b16 as its storage type in
@@ -1195,7 +1186,7 @@ std::string NVPTXTargetLowering::getPrototype(
size = promoteScalarArgumentSize(size);
O << ".param .b" << size << " _";
- } else if (isa<PointerType>(retTy)) {
+ } else if (isa<PointerType>(RetTy)) {
O << ".param .b" << PtrVT.getSizeInBits() << " _";
} else {
llvm_unreachable("Unknown return type");
@@ -1256,7 +1247,7 @@ std::string NVPTXTargetLowering::getPrototype(
if (FirstVAArg)
O << (first ? "" : ",") << " .param .align "
- << STI.getMaxRequiredAlignment() << " .b8 _[]\n";
+ << STI.getMaxRequiredAlignment() << " .b8 _[]";
O << ")";
if (shouldEmitPTXNoReturn(&CB, *nvTM))
O << " .noreturn";
@@ -1442,6 +1433,21 @@ static ISD::NodeType getExtOpcode(const ISD::ArgFlagsTy &Flags) {
return ISD::ANY_EXTEND;
}
+static SDValue correctParamType(SDValue V, EVT ExpectedVT,
+ ISD::ArgFlagsTy Flags, SelectionDAG &DAG,
+ SDLoc dl) {
+ const EVT ActualVT = V.getValueType();
+ assert((ActualVT == ExpectedVT ||
+ (ExpectedVT.isInteger() && ActualVT.isInteger())) &&
+ "Non-integer argument type size mismatch");
+ if (ExpectedVT.bitsGT(ActualVT))
+ return DAG.getNode(getExtOpcode(Flags), dl, ExpectedVT, V);
+ if (ExpectedVT.bitsLT(ActualVT))
+ return DAG.getNode(ISD::TRUNCATE, dl, ExpectedVT, V);
+
+ return V;
+}
+
SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SmallVectorImpl<SDValue> &InVals) const {
@@ -1505,9 +1511,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
"Outs and OutVals must be the same size");
// Declare the .params or .reg need to pass values
// to the function
- for (const auto [ArgI, Arg] : llvm::enumerate(Args)) {
- const auto ArgOuts = AllOuts.take_while(
- [ArgI = ArgI](auto O) { return O.OrigArgIndex == ArgI; });
+ for (const auto E : llvm::enumerate(Args)) {
+ const auto ArgI = E.index();
+ const auto Arg = E.value();
+ const auto ArgOuts =
+ AllOuts.take_while([&](auto O) { return O.OrigArgIndex == ArgI; });
const auto ArgOutVals = AllOutVals.take_front(ArgOuts.size());
AllOuts = AllOuts.drop_front(ArgOuts.size());
AllOutVals = AllOutVals.drop_front(ArgOuts.size());
@@ -1515,6 +1523,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
const bool IsVAArg = (ArgI >= FirstVAArg);
const bool IsByVal = Arg.IsByVal;
+ const SDValue ParamSymbol =
+ getCallParamSymbol(DAG, IsVAArg ? FirstVAArg : ArgI, MVT::i32);
+
SmallVector<EVT, 16> VTs;
SmallVector<uint64_t, 16> Offsets;
@@ -1525,38 +1536,43 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
assert(VTs.size() == Offsets.size() && "Size mismatch");
assert((IsByVal || VTs.size() == ArgOuts.size()) && "Size mismatch");
- Align ArgAlign;
- if (IsByVal) {
- // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
- // so we don't need to worry whether it's naturally aligned or not.
- // See TargetLowering::LowerCallTo().
- Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
- ArgAlign = getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
- InitialAlign, DL);
- if (IsVAArg)
- VAOffset = alignTo(VAOffset, ArgAlign);
- } else {
- ArgAlign = getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL);
- }
+ const Align ArgAlign = [&]() {
+ if (IsByVal) {
+ // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
+ // so we don't need to worry whether it's naturally aligned or not.
+ // See TargetLowering::LowerCallTo().
+ const Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
+ const Align ByValAlign = getFunctionByValParamAlign(
+ CB->getCalledFunction(), ETy, InitialAlign, DL);
+ if (IsVAArg)
+ VAOffset = alignTo(VAOffset, ByValAlign);
+ return ByValAlign;
+ }
+ return getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL);
+ }();
const unsigned TypeSize = DL.getTypeAllocSize(ETy);
assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) &&
"type size mismatch");
- const bool PassAsArray = IsByVal || shouldPassAsArray(Arg.Ty);
- if (IsVAArg) {
- if (ArgI == FirstVAArg) {
- VADeclareParam = Chain =
- DAG.getNode(NVPTXISD::DeclareParam, dl, {MVT::Other, MVT::Glue},
- {Chain, GetI32(STI.getMaxRequiredAlignment()),
- GetI32(ArgI), GetI32(1), InGlue});
+ const std::optional<SDValue> ArgDeclare = [&]() -> std::optional<SDValue> {
+ if (IsVAArg) {
+ if (ArgI == FirstVAArg) {
+ VADeclareParam = DAG.getNode(
+ NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue},
+ {Chain, ParamSymbol, GetI32(STI.getMaxRequiredAlignment()),
+ GetI32(0), InGlue});
+ return VADeclareParam;
+ }
+ return std::nullopt;
+ }
+ if (IsByVal || shouldPassAsArray(Arg.Ty)) {
+ // declare .param .align <align> .b8 .param<n>[<size>];
+ return DAG.getNode(NVPTXISD::DeclareArrayParam, dl,
+ {MVT::Other, MVT::Glue},
+ {Chain, ParamSymbol, GetI32(ArgAlign.value()),
+ GetI32(TypeSize), InGlue});
}
- } else if (PassAsArray) {
- // declare .param .align <align> .b8 .param<n>[<size>];
- Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, {MVT::Other, MVT::Glue},
- {Chain, GetI32(ArgAlign.value()), GetI32(ArgI),
- GetI32(TypeSize), InGlue});
- } else {
assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
// declare .param .b<size> .param<n>;
@@ -1568,11 +1584,14 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
? promoteScalarArgumentSize(TypeSize * 8)
: TypeSize * 8;
- Chain =
- DAG.getNode(NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
- {Chain, GetI32(ArgI), GetI32(PromotedSize), InGlue});
+ return DAG.getNode(NVPTXISD::DeclareScalarParam, dl,
+ {MVT::Other, MVT::Glue},
+ {Chain, ParamSymbol, GetI32(PromotedSize), InGlue});
+ }();
+ if (ArgDeclare) {
+ Chain = ArgDeclare->getValue(0);
+ InGlue = ArgDeclare->getValue(1);
}
- InGlue = Chain.getValue(1);
// PTX Interoperability Guide 3.3(A): [Integer] Values shorter
// than 32-bits are sign extended or zero extended, depending on
@@ -1594,8 +1613,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
} else {
StVal = ArgOutVals[I];
- if (auto PromotedVT = PromoteScalarIntegerPTX(StVal.getValueType())) {
- StVal = DAG.getNode(getExtOpcode(ArgOuts[I].Flags), dl, *PromotedVT,
+ auto PromotedVT = promoteScalarIntegerPTX(StVal.getValueType());
+ if (PromotedVT != StVal.getValueType()) {
+ StVal = DAG.getNode(getExtOpcode(ArgOuts[I].Flags), dl, PromotedVT,
StVal);
}
}
@@ -1619,12 +1639,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
unsigned J = 0;
for (const unsigned NumElts : VectorInfo) {
const int CurOffset = Offsets[J];
- EVT EltVT = VTs[J];
+ EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
const Align PartAlign = commonAlignment(ArgAlign, CurOffset);
- if (auto PromotedVT = PromoteScalarIntegerPTX(EltVT))
- EltVT = *PromotedVT;
-
// If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
// scalar store. In such cases, fall back to byte stores.
if (NumElts == 1 && !IsVAArg && PartAlign < DAG.getEVTAlign(EltVT)) {
@@ -1695,27 +1712,26 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
}
GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
- MaybeAlign RetAlign = std::nullopt;
// Handle Result
if (!Ins.empty()) {
- RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
-
- // Declare
- // .param .align N .b8 retval0[<size-in-bytes>], or
- // .param .b<size-in-bits> retval0
- const unsigned ResultSize = DL.getTypeAllocSizeInBits(RetTy);
- if (!shouldPassAsArray(RetTy)) {
- const unsigned PromotedResultSize = promoteScalarArgumentSize(ResultSize);
- Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, {MVT::Other, MVT::Glue},
- {Chain, GetI32(PromotedResultSize), InGlue});
- InGlue = Chain.getValue(1);
- } else {
- Chain = DAG.getNode(
- NVPTXISD::DeclareRetParam, dl, {MVT::Other, MVT::Glue},
- {Chain, GetI32(RetAlign->value()), GetI32(ResultSize / 8), InGlue});
- InGlue = Chain.getValue(1);
- }
+ const SDValue RetDeclare = [&]() {
+ const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
+ const unsigned ResultSize = DL.getTypeAllocSizeInBits(RetTy);
+ if (shouldPassAsArray(RetTy)) {
+ const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
+ return DAG.getNode(NVPTXISD::DeclareArrayParam, dl,
+ {MVT::Other, MVT::Glue},
+ {Chain, RetSymbol, GetI32(RetAlign.value()),
+ GetI32(ResultSize / 8), InGlue});
+ }
+ const auto PromotedResultSize = promoteScalarArgumentSize(ResultSize);
+ return DAG.getNode(
+ NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
+ {Chain, RetSymbol, GetI32(PromotedResultSize), InGlue});
+ }();
+ Chain = RetDeclare.getValue(0);
+ InGlue = RetDeclare.getValue(1);
}
const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
@@ -1760,7 +1776,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// The prototype is embedded in a string and put as the operand for a
// CallPrototype SDNode which will print out to the value of the string.
std::string Proto =
- getPrototype(DL, RetTy, Args, CLI.Outs, RetAlign,
+ getPrototype(DL, RetTy, Args, CLI.Outs,
HasVAArgs ? std::optional(FirstVAArg) : std::nullopt, *CB,
UniqueCallSite);
const char *ProtoStr = nvTM->getStrPool().save(Proto).data();
@@ -1773,11 +1789,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
if (ConvertToIndirectCall) {
// Copy the function ptr to a ptx register and use the register to call the
// function.
- EVT DestVT = Callee.getValueType();
- MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();
+ const MVT DestVT = Callee.getValueType().getSimpleVT();
+ MachineRegisterInfo &MRI = DAG.getMachineFunction().getRegInfo();
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
- unsigned DestReg =
- RegInfo.createVirtualRegister(TLI.getRegClassFor(DestVT.getSimpleVT()));
+ Register DestReg = MRI.createVirtualRegister(TLI.getRegClassFor(DestVT));
auto RegCopy = DAG.getCopyToReg(DAG.getEntryNode(), dl, DestReg, Callee);
Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT);
}
@@ -1810,7 +1825,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0);
assert(VTs.size() == Ins.size() && "Bad value decomposition");
- assert(RetAlign && "RetAlign is guaranteed to be set");
+ const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
// PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
// 32-bits are sign extended or zero extended, depending on whether
@@ -1818,17 +1833,15 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
const bool ExtendIntegerRetVal =
RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
- const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, *RetAlign);
+ const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
unsigned I = 0;
for (const unsigned VectorizedSize : VectorInfo) {
- EVT TheLoadType = VTs[I];
+ EVT TheLoadType = promoteScalarIntegerPTX(VTs[I]);
EVT EltType = Ins[I].VT;
- const Align EltAlign = commonAlignment(*RetAlign, Offsets[I]);
+ const Align EltAlign = commonAlignment(RetAlign, Offsets[I]);
- if (auto PromotedVT = PromoteScalarIntegerPTX(TheLoadType)) {
- TheLoadType = *PromotedVT;
- EltType = *PromotedVT;
- }
+ if (TheLoadType != VTs[I])
+ EltType = TheLoadType;
if (ExtendIntegerRetVal) {
TheLoadType = MVT::i32;
@@ -1898,13 +1911,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
continue;
}
- SDValue Ret = DAG.getNode(
- NVPTXISD::ProxyReg, dl,
- {ProxyRegOps[I].getSimpleValueType(), MVT::Other, MVT::Glue},
- {Chain, ProxyRegOps[I], InGlue});
-
- Chain = Ret.getValue(1);
- InGlue = Ret.getValue(2);
+ SDValue Ret =
+ DAG.getNode(NVPTXISD::ProxyReg, dl, ProxyRegOps[I].getSimpleValueType(),
+ {Chain, ProxyRegOps[I]});
const EVT ExpectedVT = Ins[I].VT;
if (!Ret.getValueType().bitsEq(ExpectedVT)) {
@@ -1914,14 +1923,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
}
for (SDValue &T : TempProxyRegOps) {
- SDValue Repl = DAG.getNode(NVPTXISD::ProxyReg, dl,
- {T.getSimpleValueType(), MVT::Other, MVT::Glue},
- {Chain, T.getOperand(0), InGlue});
+ SDValue Repl = DAG.getNode(NVPTXISD::ProxyReg, dl, T.getSimpleValueType(),
+ {Chain, T.getOperand(0)});
DAG.ReplaceAllUsesWith(T, Repl);
DAG.RemoveDeadNode(T.getNode());
-
- Chain = Repl.getValue(1);
- InGlue = Repl.getValue(2);
}
// set isTailCall to false for now, until we figure out how to express
@@ -3292,11 +3297,17 @@ bool NVPTXTargetLowering::splitValueIntoRegisterParts(
// Name of the symbol is composed from its index and the function name.
// Negative index corresponds to special parameter (unsized array) used for
// passing variable arguments.
-SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx,
- EVT v) const {
+SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int I,
+ EVT T) const {
StringRef SavedStr = nvTM->getStrPool().save(
- getParamName(&DAG.getMachineFunction().getFunction(), idx));
- return DAG.getExternalSymbol(SavedStr.data(), v);
+ getParamName(&DAG.getMachineFunction().getFunction(), I));
+ return DAG.getExternalSymbol(SavedStr.data(), T);
+}
+
+SDValue NVPTXTargetLowering::getCallParamSymbol(SelectionDAG &DAG, int I,
+ EVT T) const {
+ const StringRef SavedStr = nvTM->getStrPool().save("param" + Twine(I));
+ return DAG.getExternalSymbol(SavedStr.data(), T);
}
SDValue NVPTXTargetLowering::LowerFormalArguments(
@@ -3393,8 +3404,11 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
const unsigned PackingAmt =
LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1;
- const EVT VecVT = EVT::getVectorVT(
- F->getContext(), LoadVT.getScalarType(), NumElts * PackingAmt);
+ const EVT VecVT =
+ NumElts == 1
+ ? LoadVT
+ : EVT::getVectorVT(F->getContext(), LoadVT.getScalarType(),
+ NumElts * PackingAmt);
SDValue VecAddr = DAG.getObjectPtrOffset(
dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
@@ -3408,22 +3422,16 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
if (P....
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/146411
More information about the llvm-commits
mailing list