[llvm] [NVPTX][NFC] Refactor and cleanup NVPTXISelLowering call lowering 2/n (PR #137666)
Alex MacLean via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 28 09:39:53 PDT 2025
https://github.com/AlexMaclean created https://github.com/llvm/llvm-project/pull/137666
None
>From 6b242a4994fb39cf998a99145cccd596f6b8750d Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Fri, 25 Apr 2025 21:31:00 +0000
Subject: [PATCH 1/3] [NVPTX][NFC] Refactor and cleanup NVPTXISelLowering 2/n
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 475 +++++++++-----------
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 33 +-
2 files changed, 220 insertions(+), 288 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index c41741ed10232..b287822e61db9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -343,33 +343,35 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
/// and promote them to a larger size if they're not.
///
-/// The promoted type is placed in \p PromoteVT if the function returns true.
+/// Returns the promoted type if promotion is needed, otherwise std::nullopt.
-static bool PromoteScalarIntegerPTX(const EVT &VT, MVT *PromotedVT) {
+static std::optional<MVT> PromoteScalarIntegerPTX(const EVT &VT) {
if (VT.isScalarInteger()) {
+ MVT PromotedVT;
switch (PowerOf2Ceil(VT.getFixedSizeInBits())) {
default:
llvm_unreachable(
"Promotion is not suitable for scalars of size larger than 64-bits");
case 1:
- *PromotedVT = MVT::i1;
+ PromotedVT = MVT::i1;
break;
case 2:
case 4:
case 8:
- *PromotedVT = MVT::i8;
+ PromotedVT = MVT::i8;
break;
case 16:
- *PromotedVT = MVT::i16;
+ PromotedVT = MVT::i16;
break;
case 32:
- *PromotedVT = MVT::i32;
+ PromotedVT = MVT::i32;
break;
case 64:
- *PromotedVT = MVT::i64;
+ PromotedVT = MVT::i64;
break;
}
- return EVT(*PromotedVT) != VT;
+ if (VT != PromotedVT)
+ return PromotedVT;
}
- return false;
+ return std::nullopt;
}
// Check whether we can merge loads/stores of some of the pieces of a
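For context, a minimal sketch of how call sites change with the optional-returning helper (illustrative only; the real call-site updates appear in the hunks below):

  // Before: out-parameter plus bool return value.
  MVT PromotedVT;
  if (PromoteScalarIntegerPTX(EltVT, &PromotedVT))
    EltVT = EVT(PromotedVT);

  // After: the promoted type, if any, is the return value itself.
  if (std::optional<MVT> PromotedVT = PromoteScalarIntegerPTX(EltVT))
    EltVT = *PromotedVT;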
@@ -1451,8 +1453,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SelectionDAG &DAG = CLI.DAG;
SDLoc dl = CLI.DL;
- SmallVectorImpl<ISD::OutputArg> &Outs = CLI.Outs;
- SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
SDValue Chain = CLI.Chain;
SDValue Callee = CLI.Callee;
@@ -1462,6 +1462,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
const CallBase *CB = CLI.CB;
const DataLayout &DL = DAG.getDataLayout();
+ const auto GetI32 = [&](const unsigned I) {
+ return DAG.getConstant(I, dl, MVT::i32);
+ };
+
// Variadic arguments.
//
// Normally, for each argument, we declare a param scalar or a param
@@ -1479,7 +1483,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// vararg byte array.
SDValue VADeclareParam; // vararg byte array
- unsigned FirstVAArg = CLI.NumFixedArgs; // position of the first variadic
+ const unsigned FirstVAArg = CLI.NumFixedArgs; // position of first variadic
unsigned VAOffset = 0; // current offset in the param array
const unsigned UniqueCallSite = GlobalUniqueCallSite++;
@@ -1487,7 +1491,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl);
SDValue InGlue = Chain.getValue(1);
- unsigned ParamCount = 0;
// Args.size() and Outs.size() need not match.
// Outs.size() will be larger
// * if there is an aggregate argument with multiple fields (each field
@@ -1497,76 +1500,81 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// individually present in Outs.
// So a different index should be used for indexing into Outs/OutVals.
// See similar issue in LowerFormalArguments.
- unsigned OIdx = 0;
+ auto AllOuts = ArrayRef(CLI.Outs);
+ auto AllOutVals = ArrayRef(CLI.OutVals);
+ assert(AllOuts.size() == AllOutVals.size() &&
+ "Outs and OutVals must be the same size");
// Declare the .params or .reg need to pass values
// to the function
- for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
- EVT VT = Outs[OIdx].VT;
- Type *Ty = Args[i].Ty;
- bool IsVAArg = (i >= CLI.NumFixedArgs);
- bool IsByVal = Outs[OIdx].Flags.isByVal();
+ for (const auto [I, Arg] : llvm::enumerate(Args)) {
+ const auto ArgOuts =
+ AllOuts.take_while([I = I](auto O) { return O.OrigArgIndex == I; });
+ const auto ArgOutVals = AllOutVals.take_front(ArgOuts.size());
+ AllOuts = AllOuts.drop_front(ArgOuts.size());
+ AllOutVals = AllOutVals.drop_front(ArgOuts.size());
+
+ const bool IsVAArg = (I >= CLI.NumFixedArgs);
+ const bool IsByVal = Arg.IsByVal;
SmallVector<EVT, 16> VTs;
SmallVector<uint64_t, 16> Offsets;
- assert((!IsByVal || Args[i].IndirectType) &&
+ assert((!IsByVal || Arg.IndirectType) &&
"byval arg must have indirect type");
- Type *ETy = (IsByVal ? Args[i].IndirectType : Ty);
+ Type *ETy = (IsByVal ? Arg.IndirectType : Arg.Ty);
ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, IsByVal ? 0 : VAOffset);
+ assert(VTs.size() == Offsets.size() && "Size mismatch");
+ assert((IsByVal || VTs.size() == ArgOuts.size()) && "Size mismatch");
Align ArgAlign;
if (IsByVal) {
// The ByValAlign in the Outs[OIdx].Flags is always set at this point,
// so we don't need to worry whether it's naturally aligned or not.
// See TargetLowering::LowerCallTo().
- Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
+ Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
ArgAlign = getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
InitialAlign, DL);
if (IsVAArg)
VAOffset = alignTo(VAOffset, ArgAlign);
} else {
- ArgAlign = getArgumentAlignment(CB, Ty, ParamCount + 1, DL);
+ ArgAlign = getArgumentAlignment(CB, Arg.Ty, I + 1, DL);
}
- unsigned TypeSize =
- (IsByVal ? Outs[OIdx].Flags.getByValSize() : DL.getTypeAllocSize(Ty));
- SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
+ const unsigned TypeSize = DL.getTypeAllocSize(ETy);
+ assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) &&
+ "type size mismatch");
bool NeedAlign; // Does argument declaration specify alignment?
- const bool PassAsArray = IsByVal || shouldPassAsArray(Ty);
+ const bool PassAsArray = IsByVal || shouldPassAsArray(Arg.Ty);
if (IsVAArg) {
- if (ParamCount == FirstVAArg) {
- SDValue DeclareParamOps[] = {
- Chain, DAG.getConstant(STI.getMaxRequiredAlignment(), dl, MVT::i32),
- DAG.getConstant(ParamCount, dl, MVT::i32),
- DAG.getConstant(1, dl, MVT::i32), InGlue};
- VADeclareParam = Chain = DAG.getNode(NVPTXISD::DeclareParam, dl,
- DeclareParamVTs, DeclareParamOps);
+ if (I == FirstVAArg) {
+ VADeclareParam = Chain =
+ DAG.getNode(NVPTXISD::DeclareParam, dl, {MVT::Other, MVT::Glue},
+ {Chain, GetI32(STI.getMaxRequiredAlignment()),
+ GetI32(I), GetI32(1), InGlue});
}
NeedAlign = PassAsArray;
} else if (PassAsArray) {
// declare .param .align <align> .b8 .param<n>[<size>];
- SDValue DeclareParamOps[] = {
- Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32),
- DAG.getConstant(ParamCount, dl, MVT::i32),
- DAG.getConstant(TypeSize, dl, MVT::i32), InGlue};
- Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
- DeclareParamOps);
+ Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, {MVT::Other, MVT::Glue},
+ {Chain, GetI32(ArgAlign.value()), GetI32(I),
+ GetI32(TypeSize), InGlue});
NeedAlign = true;
} else {
+ assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
// declare .param .b<size> .param<n>;
- if (VT.isInteger() || VT.isFloatingPoint()) {
- // PTX ABI requires integral types to be at least 32 bits in
- // size. FP16 is loaded/stored using i16, so it's handled
- // here as well.
- TypeSize = promoteScalarArgumentSize(TypeSize * 8) / 8;
- }
- SDValue DeclareScalarParamOps[] = {
- Chain, DAG.getConstant(ParamCount, dl, MVT::i32),
- DAG.getConstant(TypeSize * 8, dl, MVT::i32),
- DAG.getConstant(0, dl, MVT::i32), InGlue};
- Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
- DeclareScalarParamOps);
+
+ // PTX ABI requires integral types to be at least 32 bits in
+ // size. FP16 is loaded/stored using i16, so it's handled
+ // here as well.
+ const unsigned PromotedSize =
+ (ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint())
+ ? promoteScalarArgumentSize(TypeSize * 8)
+ : TypeSize * 8;
+
+ Chain = DAG.getNode(
+ NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
+ {Chain, GetI32(I), GetI32(PromotedSize), GetI32(0), InGlue});
NeedAlign = false;
}
InGlue = Chain.getValue(1);
@@ -1575,8 +1583,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// than 32-bits are sign extended or zero extended, depending on
// whether they are signed or unsigned types. This case applies
// only to scalar parameters and not to aggregate values.
- bool ExtendIntegerParam =
- Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty) < 32;
+ const bool ExtendIntegerParam =
+ Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
SmallVector<SDValue, 6> StoreOperands;
@@ -1587,34 +1595,34 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
if (NeedAlign)
PartAlign = commonAlignment(ArgAlign, CurOffset);
- SDValue StVal = OutVals[OIdx];
-
- MVT PromotedVT;
- if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) {
- EltVT = EVT(PromotedVT);
- }
- if (PromoteScalarIntegerPTX(StVal.getValueType(), &PromotedVT)) {
- llvm::ISD::NodeType Ext =
- Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
- StVal = DAG.getNode(Ext, dl, PromotedVT, StVal);
- }
+ if (auto PromotedVT = PromoteScalarIntegerPTX(EltVT))
+ EltVT = *PromotedVT;
+ SDValue StVal;
if (IsByVal) {
- auto MPI = refinePtrAS(StVal, DAG, DL, *this);
- const EVT PtrVT = StVal.getValueType();
- SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
- DAG.getConstant(CurOffset, dl, PtrVT));
+ SDValue Ptr = ArgOutVals[0];
+ auto MPI = refinePtrAS(Ptr, DAG, DL, *this);
+ SDValue SrcAddr =
+ DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(CurOffset));
StVal = DAG.getLoad(EltVT, dl, TempChain, SrcAddr, MPI, PartAlign);
- } else if (ExtendIntegerParam) {
+ } else {
+ StVal = ArgOutVals[J];
+
+ if (auto PromotedVT = PromoteScalarIntegerPTX(StVal.getValueType())) {
+ llvm::ISD::NodeType Ext =
+ ArgOuts[J].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+ StVal = DAG.getNode(Ext, dl, *PromotedVT, StVal);
+ }
+ }
+
+ if (ExtendIntegerParam) {
assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
// zext/sext to i32
- StVal = DAG.getNode(Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
+ StVal = DAG.getNode(ArgOuts[J].Flags.isSExt() ? ISD::SIGN_EXTEND
: ISD::ZERO_EXTEND,
dl, MVT::i32, StVal);
- }
-
- if (!ExtendIntegerParam && EltVT.getSizeInBits() < 16) {
+ } else if (EltVT.getSizeInBits() < 16) {
// Use 16-bit registers for small stores as it's the
// smallest general purpose register size supported by NVPTX.
StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
@@ -1623,36 +1631,28 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
// scalar store. In such cases, fall back to byte stores.
if (VectorInfo[J] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() &&
- PartAlign.value() <
- DL.getABITypeAlign(EltVT.getTypeForEVT(*DAG.getContext()))) {
+ PartAlign.value() < DAG.getEVTAlign(EltVT)) {
assert(StoreOperands.empty() && "Unfinished preceeding store.");
Chain = LowerUnalignedStoreParam(
DAG, Chain, IsByVal ? CurOffset + VAOffset : CurOffset, EltVT,
- StVal, InGlue, ParamCount, dl);
+ StVal, InGlue, I, dl);
// LowerUnalignedStoreParam took care of inserting the necessary nodes
// into the SDAG, so just move on to the next element.
- if (!IsByVal)
- ++OIdx;
continue;
}
// New store.
if (VectorInfo[J] & PVF_FIRST) {
- assert(StoreOperands.empty() && "Unfinished preceding store.");
- StoreOperands.push_back(Chain);
- StoreOperands.push_back(
- DAG.getConstant(IsVAArg ? FirstVAArg : ParamCount, dl, MVT::i32));
-
- if (!IsByVal && IsVAArg) {
+ if (!IsByVal && IsVAArg)
// Align each part of the variadic argument to their type.
- VAOffset = alignTo(VAOffset, DL.getABITypeAlign(EltVT.getTypeForEVT(
- *DAG.getContext())));
- }
+ VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
- StoreOperands.push_back(DAG.getConstant(
- IsByVal ? CurOffset + VAOffset : (IsVAArg ? VAOffset : CurOffset),
- dl, MVT::i32));
+ assert(StoreOperands.empty() && "Unfinished preceding store.");
+ StoreOperands.append(
+ {Chain, GetI32(IsVAArg ? FirstVAArg : I),
+ GetI32(IsByVal ? CurOffset + VAOffset
+ : (IsVAArg ? VAOffset : CurOffset))});
}
// Record the value to store.
@@ -1699,13 +1699,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
TheStoreType.getTypeForEVT(*DAG.getContext()));
}
}
- if (!IsByVal)
- ++OIdx;
}
assert(StoreOperands.empty() && "Unfinished parameter store.");
- if (!IsByVal && VTs.size() > 0)
- --OIdx;
- ++ParamCount;
if (IsByVal && IsVAArg)
VAOffset += TypeSize;
}
@@ -1714,7 +1709,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
MaybeAlign retAlignment = std::nullopt;
// Handle Result
- if (Ins.size() > 0) {
+ if (!Ins.empty()) {
SmallVector<EVT, 16> resvtparts;
ComputeValueVTs(*this, DL, RetTy, resvtparts);
@@ -1724,47 +1719,42 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy);
if (!shouldPassAsArray(RetTy)) {
resultsz = promoteScalarArgumentSize(resultsz);
- SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
- SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
- DAG.getConstant(resultsz, dl, MVT::i32),
- DAG.getConstant(0, dl, MVT::i32), InGlue };
- Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
+ SDValue DeclareRetOps[] = {Chain, GetI32(1), GetI32(resultsz), GetI32(0),
+ InGlue};
+ Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, {MVT::Other, MVT::Glue},
DeclareRetOps);
InGlue = Chain.getValue(1);
} else {
retAlignment = getArgumentAlignment(CB, RetTy, 0, DL);
assert(retAlignment && "retAlignment is guaranteed to be set");
- SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
- SDValue DeclareRetOps[] = {
- Chain, DAG.getConstant(retAlignment->value(), dl, MVT::i32),
- DAG.getConstant(resultsz / 8, dl, MVT::i32),
- DAG.getConstant(0, dl, MVT::i32), InGlue};
- Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
- DeclareRetOps);
+ SDValue DeclareRetOps[] = {Chain, GetI32(retAlignment->value()),
+ GetI32(resultsz / 8), GetI32(0), InGlue};
+ Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl,
+ {MVT::Other, MVT::Glue}, DeclareRetOps);
InGlue = Chain.getValue(1);
}
}
- bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
+ const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
// Set the size of the vararg param byte array if the callee is a variadic
// function and the variadic part is not empty.
if (HasVAArgs) {
- SDValue DeclareParamOps[] = {
- VADeclareParam.getOperand(0), VADeclareParam.getOperand(1),
- VADeclareParam.getOperand(2), DAG.getConstant(VAOffset, dl, MVT::i32),
- VADeclareParam.getOperand(4)};
+ SDValue DeclareParamOps[] = {VADeclareParam.getOperand(0),
+ VADeclareParam.getOperand(1),
+ VADeclareParam.getOperand(2), GetI32(VAOffset),
+ VADeclareParam.getOperand(4)};
DAG.MorphNodeTo(VADeclareParam.getNode(), VADeclareParam.getOpcode(),
VADeclareParam->getVTList(), DeclareParamOps);
}
// If the type of the callsite does not match that of the function, convert
// the callsite to an indirect call.
- bool ConvertToIndirectCall = shouldConvertToIndirectCall(CB, Func);
+ const bool ConvertToIndirectCall = shouldConvertToIndirectCall(CB, Func);
// Both indirect calls and libcalls have nullptr Func. In order to distinguish
// between them we must rely on the call site value which is valid for
// indirect calls but is always null for libcalls.
- bool isIndirectCall = (!Func && CB) || ConvertToIndirectCall;
+ const bool IsIndirectCall = (!Func && CB) || ConvertToIndirectCall;
if (isa<ExternalSymbolSDNode>(Callee)) {
Function* CalleeFunc = nullptr;
@@ -1778,7 +1768,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
CalleeFunc->addFnAttr("nvptx-libcall-callee", "true");
}
- if (isIndirectCall) {
+ if (IsIndirectCall) {
// This is indirect function call case : PTX requires a prototype of the
// form
// proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
@@ -1786,9 +1776,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// instruction.
// The prototype is embedded in a string and put as the operand for a
// CallPrototype SDNode which will print out to the value of the string.
- SDVTList ProtoVTs = DAG.getVTList(MVT::Other, MVT::Glue);
std::string Proto = getPrototype(
- DL, RetTy, Args, Outs, retAlignment,
+ DL, RetTy, Args, CLI.Outs, retAlignment,
HasVAArgs
? std::optional<std::pair<unsigned, const APInt &>>(std::make_pair(
CLI.NumFixedArgs, VADeclareParam->getConstantOperandAPInt(1)))
@@ -1800,20 +1789,19 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
DAG.getTargetExternalSymbol(ProtoStr, MVT::i32),
InGlue,
};
- Chain = DAG.getNode(NVPTXISD::CallPrototype, dl, ProtoVTs, ProtoOps);
+ Chain = DAG.getNode(NVPTXISD::CallPrototype, dl, {MVT::Other, MVT::Glue},
+ ProtoOps);
InGlue = Chain.getValue(1);
}
// Op to just print "call"
- SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
- SDValue PrintCallOps[] = {
- Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, dl, MVT::i32), InGlue
- };
+ SDValue PrintCallOps[] = {Chain, GetI32(Ins.empty() ? 0 : 1), InGlue};
// We model convergent calls as separate opcodes.
- unsigned Opcode = isIndirectCall ? NVPTXISD::PrintCall : NVPTXISD::PrintCallUni;
+ unsigned Opcode =
+ IsIndirectCall ? NVPTXISD::PrintCall : NVPTXISD::PrintCallUni;
if (CLI.IsConvergent)
Opcode = Opcode == NVPTXISD::PrintCallUni ? NVPTXISD::PrintConvergentCallUni
: NVPTXISD::PrintConvergentCall;
- Chain = DAG.getNode(Opcode, dl, PrintCallVTs, PrintCallOps);
+ Chain = DAG.getNode(Opcode, dl, {MVT::Other, MVT::Glue}, PrintCallOps);
InGlue = Chain.getValue(1);
if (ConvertToIndirectCall) {
@@ -1829,43 +1817,34 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
}
// Ops to print out the function name
- SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
SDValue CallVoidOps[] = { Chain, Callee, InGlue };
- Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps);
+ Chain =
+ DAG.getNode(NVPTXISD::CallVoid, dl, {MVT::Other, MVT::Glue}, CallVoidOps);
InGlue = Chain.getValue(1);
// Ops to print out the param list
- SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
SDValue CallArgBeginOps[] = { Chain, InGlue };
- Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
+ Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, {MVT::Other, MVT::Glue},
CallArgBeginOps);
InGlue = Chain.getValue(1);
- for (unsigned i = 0, e = std::min(CLI.NumFixedArgs + 1, ParamCount); i != e;
- ++i) {
- unsigned opcode;
- if (i == (e - 1))
- opcode = NVPTXISD::LastCallArg;
- else
- opcode = NVPTXISD::CallArg;
- SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
- SDValue CallArgOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
- DAG.getConstant(i, dl, MVT::i32), InGlue };
- Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps);
+ const unsigned E = std::min<unsigned>(CLI.NumFixedArgs + 1, Args.size());
+ for (const unsigned I : llvm::seq(E)) {
+ const unsigned Opcode =
+ I == (E - 1) ? NVPTXISD::LastCallArg : NVPTXISD::CallArg;
+ SDValue CallArgOps[] = {Chain, GetI32(1), GetI32(I), InGlue};
+ Chain = DAG.getNode(Opcode, dl, {MVT::Other, MVT::Glue}, CallArgOps);
InGlue = Chain.getValue(1);
}
- SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
- SDValue CallArgEndOps[] = { Chain,
- DAG.getConstant(isIndirectCall ? 0 : 1, dl, MVT::i32),
- InGlue };
- Chain = DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps);
+ SDValue CallArgEndOps[] = {Chain, GetI32(IsIndirectCall ? 0 : 1), InGlue};
+ Chain = DAG.getNode(NVPTXISD::CallArgEnd, dl, {MVT::Other, MVT::Glue},
+ CallArgEndOps);
InGlue = Chain.getValue(1);
- if (isIndirectCall) {
- SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
- SDValue PrototypeOps[] = {
- Chain, DAG.getConstant(UniqueCallSite, dl, MVT::i32), InGlue};
- Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps);
+ if (IsIndirectCall) {
+ SDValue PrototypeOps[] = {Chain, GetI32(UniqueCallSite), InGlue};
+ Chain = DAG.getNode(NVPTXISD::Prototype, dl, {MVT::Other, MVT::Glue},
+ PrototypeOps);
InGlue = Chain.getValue(1);
}
@@ -1881,7 +1860,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SmallVector<SDValue, 16> TempProxyRegOps;
// Generate loads from param memory/moves from registers for result
- if (Ins.size() > 0) {
+ if (!Ins.empty()) {
SmallVector<EVT, 16> VTs;
SmallVector<uint64_t, 16> Offsets;
ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0);
@@ -1896,60 +1875,57 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
// 32-bits are sign extended or zero extended, depending on whether
// they are signed or unsigned types.
- bool ExtendIntegerRetVal =
+ const bool ExtendIntegerRetVal =
RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
- for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
- bool needTruncate = false;
- EVT TheLoadType = VTs[i];
- EVT EltType = Ins[i].VT;
- Align EltAlign = commonAlignment(RetAlign, Offsets[i]);
- MVT PromotedVT;
-
- if (PromoteScalarIntegerPTX(TheLoadType, &PromotedVT)) {
- TheLoadType = EVT(PromotedVT);
- EltType = EVT(PromotedVT);
- needTruncate = true;
+ for (const unsigned I : llvm::seq(VTs.size())) {
+ bool NeedTruncate = false;
+ EVT TheLoadType = VTs[I];
+ EVT EltType = Ins[I].VT;
+ Align EltAlign = commonAlignment(RetAlign, Offsets[I]);
+
+ if (auto PromotedVT = PromoteScalarIntegerPTX(TheLoadType)) {
+ TheLoadType = *PromotedVT;
+ EltType = *PromotedVT;
+ NeedTruncate = true;
}
if (ExtendIntegerRetVal) {
TheLoadType = MVT::i32;
EltType = MVT::i32;
- needTruncate = true;
+ NeedTruncate = true;
} else if (TheLoadType.getSizeInBits() < 16) {
- if (VTs[i].isInteger())
- needTruncate = true;
+ if (VTs[I].isInteger())
+ NeedTruncate = true;
EltType = MVT::i16;
}
// If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
// scalar load. In such cases, fall back to byte loads.
- if (VectorInfo[i] == PVF_SCALAR && RetTy->isAggregateType() &&
- EltAlign < DL.getABITypeAlign(
- TheLoadType.getTypeForEVT(*DAG.getContext()))) {
+ if (VectorInfo[I] == PVF_SCALAR && RetTy->isAggregateType() &&
+ EltAlign < DAG.getEVTAlign(TheLoadType)) {
assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
SDValue Ret = LowerUnalignedLoadRetParam(
- DAG, Chain, Offsets[i], TheLoadType, InGlue, TempProxyRegOps, dl);
+ DAG, Chain, Offsets[I], TheLoadType, InGlue, TempProxyRegOps, dl);
ProxyRegOps.push_back(SDValue());
ProxyRegTruncates.push_back(std::optional<MVT>());
- RetElts.resize(i);
+ RetElts.resize(I);
RetElts.push_back(Ret);
continue;
}
// Record index of the very first element of the vector.
- if (VectorInfo[i] & PVF_FIRST) {
+ if (VectorInfo[I] & PVF_FIRST) {
assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
- VecIdx = i;
+ VecIdx = I;
}
LoadVTs.push_back(EltType);
- if (VectorInfo[i] & PVF_LAST) {
- unsigned NumElts = LoadVTs.size();
- LoadVTs.push_back(MVT::Other);
- LoadVTs.push_back(MVT::Glue);
+ if (VectorInfo[I] & PVF_LAST) {
+ const unsigned NumElts = LoadVTs.size();
+ LoadVTs.append({MVT::Other, MVT::Glue});
NVPTXISD::NodeType Op;
switch (NumElts) {
case 1:
@@ -1965,21 +1941,20 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
llvm_unreachable("Invalid vector info.");
}
- SDValue LoadOperands[] = {
- Chain, DAG.getConstant(1, dl, MVT::i32),
- DAG.getConstant(Offsets[VecIdx], dl, MVT::i32), InGlue};
+ SDValue LoadOperands[] = {Chain, GetI32(1), GetI32(Offsets[VecIdx]),
+ InGlue};
SDValue RetVal = DAG.getMemIntrinsicNode(
Op, dl, DAG.getVTList(LoadVTs), LoadOperands, TheLoadType,
MachinePointerInfo(), EltAlign,
MachineMemOperand::MOLoad);
- for (unsigned j = 0; j < NumElts; ++j) {
- ProxyRegOps.push_back(RetVal.getValue(j));
+ for (const unsigned J : llvm::seq(NumElts)) {
+ ProxyRegOps.push_back(RetVal.getValue(J));
- if (needTruncate)
- ProxyRegTruncates.push_back(std::optional<MVT>(Ins[VecIdx + j].VT));
+ if (NeedTruncate)
+ ProxyRegTruncates.push_back(Ins[VecIdx + J].VT);
else
- ProxyRegTruncates.push_back(std::optional<MVT>());
+ ProxyRegTruncates.push_back(std::nullopt);
}
Chain = RetVal.getValue(NumElts);
@@ -1999,33 +1974,31 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// Append ProxyReg instructions to the chain to make sure that `callseq_end`
// will not get lost. Otherwise, during libcalls expansion, the nodes can become
// dangling.
- for (unsigned i = 0; i < ProxyRegOps.size(); ++i) {
- if (i < RetElts.size() && RetElts[i]) {
- InVals.push_back(RetElts[i]);
+ for (const unsigned I : llvm::seq(ProxyRegOps.size())) {
+ if (I < RetElts.size() && RetElts[I]) {
+ InVals.push_back(RetElts[I]);
continue;
}
SDValue Ret = DAG.getNode(
- NVPTXISD::ProxyReg, dl,
- DAG.getVTList(ProxyRegOps[i].getSimpleValueType(), MVT::Other, MVT::Glue),
- { Chain, ProxyRegOps[i], InGlue }
- );
+ NVPTXISD::ProxyReg, dl,
+ {ProxyRegOps[I].getSimpleValueType(), MVT::Other, MVT::Glue},
+ {Chain, ProxyRegOps[I], InGlue});
Chain = Ret.getValue(1);
InGlue = Ret.getValue(2);
- if (ProxyRegTruncates[i]) {
- Ret = DAG.getNode(ISD::TRUNCATE, dl, *ProxyRegTruncates[i], Ret);
+ if (ProxyRegTruncates[I]) {
+ Ret = DAG.getNode(ISD::TRUNCATE, dl, *ProxyRegTruncates[I], Ret);
}
InVals.push_back(Ret);
}
for (SDValue &T : TempProxyRegOps) {
- SDValue Repl = DAG.getNode(
- NVPTXISD::ProxyReg, dl,
- DAG.getVTList(T.getSimpleValueType(), MVT::Other, MVT::Glue),
- {Chain, T.getOperand(0), InGlue});
+ SDValue Repl = DAG.getNode(NVPTXISD::ProxyReg, dl,
+ {T.getSimpleValueType(), MVT::Other, MVT::Glue},
+ {Chain, T.getOperand(0), InGlue});
DAG.ReplaceAllUsesWith(T, Repl);
DAG.RemoveDeadNode(T.getNode());
@@ -3451,29 +3424,29 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
// That's the last element of this store op.
if (VectorInfo[PartI] & PVF_LAST) {
const unsigned NumElts = PartI - VecIdx + 1;
- EVT EltVT = VTs[PartI];
- // i1 is loaded/stored as i8.
- EVT LoadVT = EltVT;
- if (EltVT == MVT::i1)
- LoadVT = MVT::i8;
- else if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
+ const EVT EltVT = VTs[PartI];
+ const EVT LoadVT = [&]() -> EVT {
+ // i1 is loaded/stored as i8.
+ if (EltVT == MVT::i1)
+ return MVT::i8;
// getLoad needs a vector type, but it can't handle
// vectors which contain v2f16 or v2bf16 elements. So we must load
// using i32 here and then bitcast back.
- LoadVT = MVT::i32;
+ if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
+ return MVT::i32;
+ return EltVT;
+ }();
- EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
- SDValue VecAddr =
- DAG.getNode(ISD::ADD, dl, PtrVT, ArgSymbol,
- DAG.getConstant(Offsets[VecIdx], dl, PtrVT));
+ const EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
+ SDValue VecAddr = DAG.getObjectPtrOffset(
+ dl, ArgSymbol, TypeSize::getFixed(Offsets[VecIdx]));
const MaybeAlign PartAlign = [&]() -> MaybeAlign {
if (aggregateIsPacked)
return Align(1);
if (NumElts != 1)
return std::nullopt;
- Align PartAlign =
- DL.getABITypeAlign(EltVT.getTypeForEVT(F->getContext()));
+ Align PartAlign = DAG.getEVTAlign(EltVT);
return commonAlignment(PartAlign, Offsets[PartI]);
}();
SDValue P =
@@ -3486,26 +3459,20 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
for (const unsigned J : llvm::seq(NumElts)) {
SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
DAG.getIntPtrConstant(J, dl));
- // We've loaded i1 as an i8 and now must truncate it back to i1
- if (EltVT == MVT::i1)
- Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
- // v2f16 was loaded as an i32. Now we must bitcast it back.
- Elt = DAG.getBitcast(EltVT, Elt);
-
- // If a promoted integer type is used, truncate down to the original
- MVT PromotedVT;
- if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) {
- Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
- }
- // Extend the element if necessary (e.g. an i8 is loaded
+ // Extend or truncate the element if necessary (e.g. an i8 is loaded
// into an i16 register)
- if (ArgIns[PartI].VT.getFixedSizeInBits() !=
- LoadVT.getFixedSizeInBits()) {
- assert(ArgIns[PartI].VT.isInteger() && LoadVT.isInteger() &&
+ const EVT ExpactedVT = ArgIns[PartI].VT;
+ if (ExpactedVT.getFixedSizeInBits() !=
+ Elt.getValueType().getFixedSizeInBits()) {
+ assert(ExpactedVT.isScalarInteger() &&
+ Elt.getValueType().isScalarInteger() &&
"Non-integer argument type size mismatch");
Elt = DAG.getExtOrTrunc(ArgIns[PartI].Flags.isSExt(), Elt, dl,
- ArgIns[PartI].VT);
+ ExpactedVT);
+ } else {
+ // v2f16 was loaded as an i32. Now we must bitcast it back.
+ Elt = DAG.getBitcast(EltVT, Elt);
}
InVals.push_back(Elt);
}
@@ -3561,47 +3528,37 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
Type *RetTy = MF.getFunction().getReturnType();
const DataLayout &DL = DAG.getDataLayout();
- SmallVector<SDValue, 16> PromotedOutVals;
SmallVector<EVT, 16> VTs;
SmallVector<uint64_t, 16> Offsets;
ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
- for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
- SDValue PromotedOutVal = OutVals[i];
- MVT PromotedVT;
- if (PromoteScalarIntegerPTX(VTs[i], &PromotedVT)) {
- VTs[i] = EVT(PromotedVT);
- }
- if (PromoteScalarIntegerPTX(PromotedOutVal.getValueType(), &PromotedVT)) {
- llvm::ISD::NodeType Ext =
- Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
- PromotedOutVal = DAG.getNode(Ext, dl, PromotedVT, PromotedOutVal);
- }
- PromotedOutVals.push_back(PromotedOutVal);
- }
+ for (const unsigned I : llvm::seq(VTs.size()))
+ if (auto PromotedVT = PromoteScalarIntegerPTX(VTs[I]))
+ VTs[I] = *PromotedVT;
auto VectorInfo = VectorizePTXValueVTs(
VTs, Offsets,
- RetTy->isSized() ? getFunctionParamOptimizedAlign(&F, RetTy, DL)
+ !RetTy->isVoidTy() ? getFunctionParamOptimizedAlign(&F, RetTy, DL)
: Align(1));
// PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
// 32-bits are sign extended or zero extended, depending on whether
// they are signed or unsigned types.
- bool ExtendIntegerRetVal =
+ const bool ExtendIntegerRetVal =
RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
SmallVector<SDValue, 6> StoreOperands;
- for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
- SDValue OutVal = OutVals[i];
- SDValue RetVal = PromotedOutVals[i];
+ for (const unsigned I : llvm::seq(VTs.size())) {
+ SDValue RetVal = OutVals[I];
+ assert(!PromoteScalarIntegerPTX(RetVal.getValueType()) &&
+ "OutVal type should always be legal");
if (ExtendIntegerRetVal) {
- RetVal = DAG.getNode(Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND
+ RetVal = DAG.getNode(Outs[I].Flags.isSExt() ? ISD::SIGN_EXTEND
: ISD::ZERO_EXTEND,
dl, MVT::i32, RetVal);
- } else if (OutVal.getValueSizeInBits() < 16) {
+ } else if (RetVal.getValueSizeInBits() < 16) {
// Use 16-bit registers for small load-stores as it's the
// smallest general purpose register size supported by NVPTX.
RetVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, RetVal);
@@ -3609,15 +3566,14 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
// If we have a PVF_SCALAR entry, it may not even be sufficiently aligned
// for a scalar store. In such cases, fall back to byte stores.
- if (VectorInfo[i] == PVF_SCALAR && RetTy->isAggregateType()) {
- EVT ElementType = ExtendIntegerRetVal ? MVT::i32 : VTs[i];
- Align ElementTypeAlign =
- DL.getABITypeAlign(ElementType.getTypeForEVT(RetTy->getContext()));
- Align ElementAlign =
- commonAlignment(DL.getABITypeAlign(RetTy), Offsets[i]);
+ if (VectorInfo[I] == PVF_SCALAR && RetTy->isAggregateType()) {
+ const EVT ElementType = ExtendIntegerRetVal ? MVT::i32 : VTs[I];
+ const Align ElementTypeAlign = DAG.getEVTAlign(ElementType);
+ const Align ElementAlign =
+ commonAlignment(DL.getABITypeAlign(RetTy), Offsets[I]);
if (ElementAlign < ElementTypeAlign) {
assert(StoreOperands.empty() && "Orphaned operand list.");
- Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[i], ElementType,
+ Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[I], ElementType,
RetVal, dl);
// The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes
@@ -3627,17 +3583,16 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
}
// New load/store. Record chain and offset operands.
- if (VectorInfo[i] & PVF_FIRST) {
+ if (VectorInfo[I] & PVF_FIRST) {
assert(StoreOperands.empty() && "Orphaned operand list.");
- StoreOperands.push_back(Chain);
- StoreOperands.push_back(DAG.getConstant(Offsets[i], dl, MVT::i32));
+ StoreOperands.append({Chain, DAG.getConstant(Offsets[I], dl, MVT::i32)});
}
// Record the value to return.
StoreOperands.push_back(RetVal);
// That's the last element of this store op.
- if (VectorInfo[i] & PVF_LAST) {
+ if (VectorInfo[I] & PVF_LAST) {
NVPTXISD::NodeType Op;
unsigned NumElts = StoreOperands.size() - 2;
switch (NumElts) {
@@ -3656,7 +3611,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
// Adjust type of load/store op if we've extended the scalar
// return value.
- EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[i];
+ EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[I];
Chain = DAG.getMemIntrinsicNode(
Op, dl, DAG.getVTList(MVT::Other), StoreOperands, TheStoreType,
MachinePointerInfo(), Align(1), MachineMemOperand::MOStore);
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 043da14bcb236..77be311f4e496 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -2050,8 +2050,7 @@ def SDTDeclareScalarParamProfile :
def SDTLoadParamProfile : SDTypeProfile<1, 2, [SDTCisInt<1>, SDTCisInt<2>]>;
def SDTLoadParamV2Profile : SDTypeProfile<2, 2, [SDTCisSameAs<0, 1>, SDTCisInt<2>, SDTCisInt<3>]>;
def SDTLoadParamV4Profile : SDTypeProfile<4, 2, [SDTCisInt<4>, SDTCisInt<5>]>;
-def SDTPrintCallProfile : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
-def SDTPrintCallUniProfile : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
+def SDTPrintCallProfile : SDTypeProfile<0, 1, [SDTCisVT<0, i32>]>;
def SDTStoreParamProfile : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>]>;
def SDTStoreParamV2Profile : SDTypeProfile<0, 4, [SDTCisInt<0>, SDTCisInt<1>]>;
def SDTStoreParamV4Profile : SDTypeProfile<0, 6, [SDTCisInt<0>, SDTCisInt<1>]>;
@@ -2095,10 +2094,10 @@ def PrintConvergentCall :
SDNode<"NVPTXISD::PrintConvergentCall", SDTPrintCallProfile,
[SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
def PrintCallUni :
- SDNode<"NVPTXISD::PrintCallUni", SDTPrintCallUniProfile,
+ SDNode<"NVPTXISD::PrintCallUni", SDTPrintCallProfile,
[SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
def PrintConvergentCallUni :
- SDNode<"NVPTXISD::PrintConvergentCallUni", SDTPrintCallUniProfile,
+ SDNode<"NVPTXISD::PrintConvergentCallUni", SDTPrintCallProfile,
[SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
def StoreParam :
SDNode<"NVPTXISD::StoreParam", SDTStoreParamProfile,
@@ -2247,31 +2246,9 @@ let mayStore = true in {
let isCall=1 in {
multiclass CALL<string OpcStr, SDNode OpNode> {
def PrintCallNoRetInst : NVPTXInst<(outs), (ins),
- !strconcat(OpcStr, " "), [(OpNode (i32 0))]>;
+ OpcStr # " ", [(OpNode 0)]>;
def PrintCallRetInst1 : NVPTXInst<(outs), (ins),
- !strconcat(OpcStr, " (retval0), "), [(OpNode (i32 1))]>;
- def PrintCallRetInst2 : NVPTXInst<(outs), (ins),
- !strconcat(OpcStr, " (retval0, retval1), "), [(OpNode (i32 2))]>;
- def PrintCallRetInst3 : NVPTXInst<(outs), (ins),
- !strconcat(OpcStr, " (retval0, retval1, retval2), "), [(OpNode (i32 3))]>;
- def PrintCallRetInst4 : NVPTXInst<(outs), (ins),
- !strconcat(OpcStr, " (retval0, retval1, retval2, retval3), "),
- [(OpNode (i32 4))]>;
- def PrintCallRetInst5 : NVPTXInst<(outs), (ins),
- !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4), "),
- [(OpNode (i32 5))]>;
- def PrintCallRetInst6 : NVPTXInst<(outs), (ins),
- !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, "
- "retval5), "),
- [(OpNode (i32 6))]>;
- def PrintCallRetInst7 : NVPTXInst<(outs), (ins),
- !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, "
- "retval5, retval6), "),
- [(OpNode (i32 7))]>;
- def PrintCallRetInst8 : NVPTXInst<(outs), (ins),
- !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, "
- "retval5, retval6, retval7), "),
- [(OpNode (i32 8))]>;
+ OpcStr # " (retval0), ", [(OpNode 1)]>;
}
}
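A self-contained sketch (illustrative only, not part of the patch) of the grouping idiom the rewritten LowerCall loop relies on: ArrayRef::take_while peels off the flattened Outs entries that belong to the current IR argument, so the separate OIdx cursor is no longer needed. The Out struct is a stand-in for ISD::OutputArg, and the patch itself pairs this with llvm::enumerate; a plain index loop is used here for brevity.

  #include "llvm/ADT/ArrayRef.h"
  #include <cstdio>

  struct Out { unsigned OrigArgIndex; }; // stand-in for ISD::OutputArg

  static void walkArgs(unsigned NumArgs, llvm::ArrayRef<Out> AllOuts) {
    for (unsigned I = 0; I != NumArgs; ++I) {
      // All flattened pieces of argument I sit at the front of AllOuts.
      const auto ArgOuts = AllOuts.take_while(
          [I](const Out &O) { return O.OrigArgIndex == I; });
      AllOuts = AllOuts.drop_front(ArgOuts.size());
      std::printf("arg %u -> %zu flattened pieces\n", I, ArgOuts.size());
    }
  }

  // e.g. an i32 plus a struct flattened into two pieces:
  //   walkArgs(2, {{0}, {1}, {1}}) prints "arg 0 -> 1" and "arg 1 -> 2".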
>From baf403827be872dda63ab49ab0ac91470ab4397f Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Sun, 27 Apr 2025 17:57:23 +0000
Subject: [PATCH 2/3] [NVPTX][NFC] Refactor parameter vectorization
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 528 +++++++++-----------
1 file changed, 237 insertions(+), 291 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index b287822e61db9..b8c419bee53ae 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -428,16 +428,6 @@ static unsigned CanMergeParamLoadStoresStartingAt(
return NumElts;
}
-// Flags for tracking per-element vectorization state of loads/stores
-// of a flattened function parameter or return value.
-enum ParamVectorizationFlags {
- PVF_INNER = 0x0, // Middle elements of a vector.
- PVF_FIRST = 0x1, // First element of the vector.
- PVF_LAST = 0x2, // Last element of the vector.
- // Scalar is effectively a 1-element vector.
- PVF_SCALAR = PVF_FIRST | PVF_LAST
-};
-
// Computes whether and how we can vectorize the loads/stores of a
// flattened function parameter or return value.
//
@@ -446,52 +436,39 @@ enum ParamVectorizationFlags {
// of the same size as ValueVTs indicating how each piece should be
// loaded/stored (i.e. as a scalar, or as part of a vector
// load/store).
-static SmallVector<ParamVectorizationFlags, 16>
+static SmallVector<unsigned, 16>
VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
const SmallVectorImpl<uint64_t> &Offsets,
Align ParamAlignment, bool IsVAArg = false) {
// Set vector size to match ValueVTs and mark all elements as
// scalars by default.
- SmallVector<ParamVectorizationFlags, 16> VectorInfo;
- VectorInfo.assign(ValueVTs.size(), PVF_SCALAR);
+ SmallVector<unsigned, 16> VectorInfo;
- if (IsVAArg)
+ if (IsVAArg) {
+ VectorInfo.assign(ValueVTs.size(), 1);
return VectorInfo;
+ }
- // Check what we can vectorize using 128/64/32-bit accesses.
- for (int I = 0, E = ValueVTs.size(); I != E; ++I) {
- // Skip elements we've already processed.
- assert(VectorInfo[I] == PVF_SCALAR && "Unexpected vector info state.");
- for (unsigned AccessSize : {16, 8, 4, 2}) {
- unsigned NumElts = CanMergeParamLoadStoresStartingAt(
+ const auto GetNumElts = [&](unsigned I) -> unsigned {
+ for (const unsigned AccessSize : {16, 8, 4, 2}) {
+ const unsigned NumElts = CanMergeParamLoadStoresStartingAt(
I, AccessSize, ValueVTs, Offsets, ParamAlignment);
- // Mark vectorized elements.
- switch (NumElts) {
- default:
- llvm_unreachable("Unexpected return value");
- case 1:
- // Can't vectorize using this size, try next smaller size.
- continue;
- case 2:
- assert(I + 1 < E && "Not enough elements.");
- VectorInfo[I] = PVF_FIRST;
- VectorInfo[I + 1] = PVF_LAST;
- I += 1;
- break;
- case 4:
- assert(I + 3 < E && "Not enough elements.");
- VectorInfo[I] = PVF_FIRST;
- VectorInfo[I + 1] = PVF_INNER;
- VectorInfo[I + 2] = PVF_INNER;
- VectorInfo[I + 3] = PVF_LAST;
- I += 3;
- break;
- }
- // Break out of the inner loop because we've already succeeded
- // using largest possible AccessSize.
- break;
+ assert((NumElts == 1 || NumElts == 2 || NumElts == 4) &&
+ "Unexpected vectorization size");
+ if (NumElts != 1)
+ return NumElts;
}
+ return 1;
+ };
+
+ // Check what we can vectorize using 128/64/32-bit accesses.
+ for (unsigned I = 0, E = ValueVTs.size(); I != E;) {
+ const unsigned NumElts = GetNumElts(I);
+ VectorInfo.push_back(NumElts);
+ I += NumElts;
}
+ assert(std::accumulate(VectorInfo.begin(), VectorInfo.end(), 0u) ==
+ ValueVTs.size());
return VectorInfo;
}
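As an illustration of the encoding change (hypothetical element layout, not taken from the patch), consider five pieces where the first four 32-bit elements can be merged into a single 128-bit access and the last one cannot:

  // Old per-element flags:
  //   {PVF_FIRST, PVF_INNER, PVF_INNER, PVF_LAST, PVF_SCALAR}
  // New per-group element counts:
  //   {4, 1}
  // The counts always sum to ValueVTs.size(), which the new
  // std::accumulate assertion above checks.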
@@ -1513,7 +1490,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
AllOuts = AllOuts.drop_front(ArgOuts.size());
AllOutVals = AllOutVals.drop_front(ArgOuts.size());
- const bool IsVAArg = (I >= CLI.NumFixedArgs);
+ const bool IsVAArg = (I >= FirstVAArg);
const bool IsByVal = Arg.IsByVal;
SmallVector<EVT, 16> VTs;
@@ -1586,32 +1563,22 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
const bool ExtendIntegerParam =
Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
- auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
- SmallVector<SDValue, 6> StoreOperands;
- for (const unsigned J : llvm::seq(VTs.size())) {
- EVT EltVT = VTs[J];
- const int CurOffset = Offsets[J];
- MaybeAlign PartAlign;
- if (NeedAlign)
- PartAlign = commonAlignment(ArgAlign, CurOffset);
-
- if (auto PromotedVT = PromoteScalarIntegerPTX(EltVT))
- EltVT = *PromotedVT;
-
+ const auto GetStoredValue = [&](const unsigned I, EVT EltVT,
+ MaybeAlign PartAlign) {
SDValue StVal;
if (IsByVal) {
SDValue Ptr = ArgOutVals[0];
auto MPI = refinePtrAS(Ptr, DAG, DL, *this);
SDValue SrcAddr =
- DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(CurOffset));
+ DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(Offsets[I]));
StVal = DAG.getLoad(EltVT, dl, TempChain, SrcAddr, MPI, PartAlign);
} else {
- StVal = ArgOutVals[J];
+ StVal = ArgOutVals[I];
if (auto PromotedVT = PromoteScalarIntegerPTX(StVal.getValueType())) {
llvm::ISD::NodeType Ext =
- ArgOuts[J].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+ ArgOuts[I].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
StVal = DAG.getNode(Ext, dl, *PromotedVT, StVal);
}
}
@@ -1619,7 +1586,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
if (ExtendIntegerParam) {
assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
// zext/sext to i32
- StVal = DAG.getNode(ArgOuts[J].Flags.isSExt() ? ISD::SIGN_EXTEND
+ StVal = DAG.getNode(ArgOuts[I].Flags.isSExt() ? ISD::SIGN_EXTEND
: ISD::ZERO_EXTEND,
dl, MVT::i32, StVal);
} else if (EltVT.getSizeInBits() < 16) {
@@ -1627,81 +1594,91 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// smallest general purpose register size supported by NVPTX.
StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
}
+ return StVal;
+ };
+
+ const auto VectorInfo =
+ VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
+
+ unsigned J = 0;
+ for (const unsigned NumElts : VectorInfo) {
+ const int CurOffset = Offsets[J];
+ EVT EltVT = VTs[J];
+ MaybeAlign PartAlign;
+ if (NeedAlign)
+ PartAlign = commonAlignment(ArgAlign, CurOffset);
+
+ if (auto PromotedVT = PromoteScalarIntegerPTX(EltVT))
+ EltVT = *PromotedVT;
// If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
// scalar store. In such cases, fall back to byte stores.
- if (VectorInfo[J] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() &&
+ if (NumElts == 1 && !IsVAArg && PartAlign.has_value() &&
PartAlign.value() < DAG.getEVTAlign(EltVT)) {
- assert(StoreOperands.empty() && "Unfinished preceeding store.");
- Chain = LowerUnalignedStoreParam(
- DAG, Chain, IsByVal ? CurOffset + VAOffset : CurOffset, EltVT,
- StVal, InGlue, I, dl);
+
+ SDValue StVal = GetStoredValue(J, EltVT, PartAlign);
+ Chain = LowerUnalignedStoreParam(DAG, Chain,
+ CurOffset + (IsByVal ? VAOffset : 0),
+ EltVT, StVal, InGlue, I, dl);
// LowerUnalignedStoreParam took care of inserting the necessary nodes
// into the SDAG, so just move on to the next element.
+ J++;
continue;
}
- // New store.
- if (VectorInfo[J] & PVF_FIRST) {
- if (!IsByVal && IsVAArg)
- // Align each part of the variadic argument to their type.
- VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
-
- assert(StoreOperands.empty() && "Unfinished preceding store.");
- StoreOperands.append(
- {Chain, GetI32(IsVAArg ? FirstVAArg : I),
- GetI32(IsByVal ? CurOffset + VAOffset
- : (IsVAArg ? VAOffset : CurOffset))});
- }
+ if (IsVAArg && !IsByVal)
+ // Align each part of the variadic argument to their type.
+ VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
- // Record the value to store.
- StoreOperands.push_back(StVal);
+ assert((IsVAArg || VAOffset == 0) &&
+ "VAOffset must be 0 for non-VA args");
+ SmallVector<SDValue, 6> StoreOperands{
+ Chain, GetI32(IsVAArg ? FirstVAArg : I),
+ GetI32(IsByVal ? CurOffset + VAOffset
+ : (IsVAArg ? VAOffset : CurOffset))};
- if (VectorInfo[J] & PVF_LAST) {
- const unsigned NumElts = StoreOperands.size() - 3;
- NVPTXISD::NodeType Op;
- switch (NumElts) {
- case 1:
- Op = NVPTXISD::StoreParam;
- break;
- case 2:
- Op = NVPTXISD::StoreParamV2;
- break;
- case 4:
- Op = NVPTXISD::StoreParamV4;
- break;
- default:
- llvm_unreachable("Invalid vector info.");
- }
+ // Record the values to store.
+ for (const unsigned K : llvm::seq(NumElts))
+ StoreOperands.push_back(GetStoredValue(J + K, EltVT, PartAlign));
+ StoreOperands.push_back(InGlue);
- StoreOperands.push_back(InGlue);
-
- // Adjust type of the store op if we've extended the scalar
- // return value.
- EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
-
- Chain = DAG.getMemIntrinsicNode(
- Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands,
- TheStoreType, MachinePointerInfo(), PartAlign,
- MachineMemOperand::MOStore);
- InGlue = Chain.getValue(1);
+ NVPTXISD::NodeType Op;
+ switch (NumElts) {
+ case 1:
+ Op = NVPTXISD::StoreParam;
+ break;
+ case 2:
+ Op = NVPTXISD::StoreParamV2;
+ break;
+ case 4:
+ Op = NVPTXISD::StoreParamV4;
+ break;
+ default:
+ llvm_unreachable("Invalid vector info.");
+ }
+ // Adjust type of the store op if we've extended the scalar
+ // return value.
+ EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
- // Cleanup.
- StoreOperands.clear();
+ Chain = DAG.getMemIntrinsicNode(
+ Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands,
+ TheStoreType, MachinePointerInfo(), PartAlign,
+ MachineMemOperand::MOStore);
+ InGlue = Chain.getValue(1);
- // TODO: We may need to support vector types that can be passed
- // as scalars in variadic arguments.
- if (!IsByVal && IsVAArg) {
- assert(NumElts == 1 &&
- "Vectorization is expected to be disabled for variadics.");
- VAOffset += DL.getTypeAllocSize(
- TheStoreType.getTypeForEVT(*DAG.getContext()));
- }
+ // TODO: We may need to support vector types that can be passed
+ // as scalars in variadic arguments.
+ if (IsVAArg && !IsByVal) {
+ assert(NumElts == 1 &&
+ "Vectorization is expected to be disabled for variadics.");
+ VAOffset +=
+ DL.getTypeAllocSize(TheStoreType.getTypeForEVT(*DAG.getContext()));
}
+
+ J += NumElts;
}
- assert(StoreOperands.empty() && "Unfinished parameter store.");
- if (IsByVal && IsVAArg)
+ if (IsVAArg && IsByVal)
VAOffset += TypeSize;
}
@@ -1716,10 +1693,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// Declare
// .param .align N .b8 retval0[<size-in-bytes>], or
// .param .b<size-in-bits> retval0
- unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy);
+ const unsigned ResultSize = DL.getTypeAllocSizeInBits(RetTy);
if (!shouldPassAsArray(RetTy)) {
- resultsz = promoteScalarArgumentSize(resultsz);
- SDValue DeclareRetOps[] = {Chain, GetI32(1), GetI32(resultsz), GetI32(0),
+ const unsigned PromotedResultSize = promoteScalarArgumentSize(ResultSize);
+ SDValue DeclareRetOps[] = {Chain, GetI32(1), GetI32(PromotedResultSize), GetI32(0),
InGlue};
Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, {MVT::Other, MVT::Glue},
DeclareRetOps);
@@ -1728,7 +1705,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
retAlignment = getArgumentAlignment(CB, RetTy, 0, DL);
assert(retAlignment && "retAlignment is guaranteed to be set");
SDValue DeclareRetOps[] = {Chain, GetI32(retAlignment->value()),
- GetI32(resultsz / 8), GetI32(0), InGlue};
+ GetI32(ResultSize / 8), GetI32(0), InGlue};
Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl,
{MVT::Other, MVT::Glue}, DeclareRetOps);
InGlue = Chain.getValue(1);
@@ -1866,11 +1843,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0);
assert(VTs.size() == Ins.size() && "Bad value decomposition");
- Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
- auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
-
- SmallVector<EVT, 6> LoadVTs;
- int VecIdx = -1; // Index of the first element of the vector.
+ const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
// PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
// 32-bits are sign extended or zero extended, depending on whether
@@ -1878,11 +1851,13 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
const bool ExtendIntegerRetVal =
RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
- for (const unsigned I : llvm::seq(VTs.size())) {
+ const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
+ unsigned I = 0;
+ for (const unsigned VectorizedSize : VectorInfo) {
bool NeedTruncate = false;
EVT TheLoadType = VTs[I];
EVT EltType = Ins[I].VT;
- Align EltAlign = commonAlignment(RetAlign, Offsets[I]);
+ const Align EltAlign = commonAlignment(RetAlign, Offsets[I]);
if (auto PromotedVT = PromoteScalarIntegerPTX(TheLoadType)) {
TheLoadType = *PromotedVT;
@@ -1895,16 +1870,15 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
EltType = MVT::i32;
NeedTruncate = true;
} else if (TheLoadType.getSizeInBits() < 16) {
- if (VTs[I].isInteger())
+ if (TheLoadType.isInteger())
NeedTruncate = true;
EltType = MVT::i16;
}
// If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
// scalar load. In such cases, fall back to byte loads.
- if (VectorInfo[I] == PVF_SCALAR && RetTy->isAggregateType() &&
+ if (VectorizedSize == 1 && RetTy->isAggregateType() &&
EltAlign < DAG.getEVTAlign(TheLoadType)) {
- assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
SDValue Ret = LowerUnalignedLoadRetParam(
DAG, Chain, Offsets[I], TheLoadType, InGlue, TempProxyRegOps, dl);
ProxyRegOps.push_back(SDValue());
@@ -1912,58 +1886,46 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
RetElts.resize(I);
RetElts.push_back(Ret);
+ I++;
continue;
}
- // Record index of the very first element of the vector.
- if (VectorInfo[I] & PVF_FIRST) {
- assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
- VecIdx = I;
- }
-
- LoadVTs.push_back(EltType);
+ SmallVector<EVT, 6> LoadVTs(VectorizedSize, EltType);
+ LoadVTs.append({MVT::Other, MVT::Glue});
- if (VectorInfo[I] & PVF_LAST) {
- const unsigned NumElts = LoadVTs.size();
- LoadVTs.append({MVT::Other, MVT::Glue});
- NVPTXISD::NodeType Op;
- switch (NumElts) {
- case 1:
- Op = NVPTXISD::LoadParam;
- break;
- case 2:
- Op = NVPTXISD::LoadParamV2;
- break;
- case 4:
- Op = NVPTXISD::LoadParamV4;
- break;
- default:
- llvm_unreachable("Invalid vector info.");
- }
+ NVPTXISD::NodeType Op;
+ switch (VectorizedSize) {
+ case 1:
+ Op = NVPTXISD::LoadParam;
+ break;
+ case 2:
+ Op = NVPTXISD::LoadParamV2;
+ break;
+ case 4:
+ Op = NVPTXISD::LoadParamV4;
+ break;
+ default:
+ llvm_unreachable("Invalid vector info.");
+ }
- SDValue LoadOperands[] = {Chain, GetI32(1), GetI32(Offsets[VecIdx]),
- InGlue};
- SDValue RetVal = DAG.getMemIntrinsicNode(
- Op, dl, DAG.getVTList(LoadVTs), LoadOperands, TheLoadType,
- MachinePointerInfo(), EltAlign,
- MachineMemOperand::MOLoad);
+ SDValue LoadOperands[] = {Chain, GetI32(1), GetI32(Offsets[I]), InGlue};
+ SDValue RetVal = DAG.getMemIntrinsicNode(
+ Op, dl, DAG.getVTList(LoadVTs), LoadOperands, TheLoadType,
+ MachinePointerInfo(), EltAlign, MachineMemOperand::MOLoad);
- for (const unsigned J : llvm::seq(NumElts)) {
- ProxyRegOps.push_back(RetVal.getValue(J));
+ for (const unsigned J : llvm::seq(VectorizedSize)) {
+ ProxyRegOps.push_back(RetVal.getValue(J));
- if (NeedTruncate)
- ProxyRegTruncates.push_back(Ins[VecIdx + J].VT);
- else
- ProxyRegTruncates.push_back(std::nullopt);
- }
+ if (NeedTruncate)
+ ProxyRegTruncates.push_back(Ins[I + J].VT);
+ else
+ ProxyRegTruncates.push_back(std::nullopt);
+ }
- Chain = RetVal.getValue(NumElts);
- InGlue = RetVal.getValue(NumElts + 1);
+ Chain = RetVal.getValue(VectorizedSize);
+ InGlue = RetVal.getValue(VectorizedSize + 1);
- // Cleanup
- VecIdx = -1;
- LoadVTs.clear();
- }
+ I += VectorizedSize;
}
}
@@ -3409,77 +3371,65 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
assert(VTs.size() == ArgIns.size() && "Size mismatch");
assert(VTs.size() == Offsets.size() && "Size mismatch");
- Align ArgAlign = getFunctionArgumentAlignment(
+ const Align ArgAlign = getFunctionArgumentAlignment(
F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
- auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
- assert(VectorInfo.size() == VTs.size() && "Size mismatch");
-
- int VecIdx = -1; // Index of the first element of the current vector.
- for (const unsigned PartI : llvm::seq(VTs.size())) {
- if (VectorInfo[PartI] & PVF_FIRST) {
- assert(VecIdx == -1 && "Orphaned vector.");
- VecIdx = PartI;
- }
- // That's the last element of this store op.
- if (VectorInfo[PartI] & PVF_LAST) {
- const unsigned NumElts = PartI - VecIdx + 1;
- const EVT EltVT = VTs[PartI];
- const EVT LoadVT = [&]() -> EVT {
- // i1 is loaded/stored as i8.
- if (EltVT == MVT::i1)
- return MVT::i8;
- // getLoad needs a vector type, but it can't handle
- // vectors which contain v2f16 or v2bf16 elements. So we must load
- // using i32 here and then bitcast back.
- if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
- return MVT::i32;
- return EltVT;
- }();
-
- const EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
- SDValue VecAddr = DAG.getObjectPtrOffset(
- dl, ArgSymbol, TypeSize::getFixed(Offsets[VecIdx]));
-
- const MaybeAlign PartAlign = [&]() -> MaybeAlign {
- if (aggregateIsPacked)
- return Align(1);
- if (NumElts != 1)
- return std::nullopt;
- Align PartAlign = DAG.getEVTAlign(EltVT);
- return commonAlignment(PartAlign, Offsets[PartI]);
- }();
- SDValue P =
- DAG.getLoad(VecVT, dl, Root, VecAddr,
- MachinePointerInfo(ADDRESS_SPACE_PARAM), PartAlign,
- MachineMemOperand::MODereferenceable |
- MachineMemOperand::MOInvariant);
- if (P.getNode())
- P.getNode()->setIROrder(Arg.getArgNo() + 1);
- for (const unsigned J : llvm::seq(NumElts)) {
- SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
- DAG.getIntPtrConstant(J, dl));
-
- // Extend or truncate the element if necessary (e.g. an i8 is loaded
- // into an i16 register)
- const EVT ExpactedVT = ArgIns[PartI].VT;
- if (ExpactedVT.getFixedSizeInBits() !=
- Elt.getValueType().getFixedSizeInBits()) {
- assert(ExpactedVT.isScalarInteger() &&
- Elt.getValueType().isScalarInteger() &&
- "Non-integer argument type size mismatch");
- Elt = DAG.getExtOrTrunc(ArgIns[PartI].Flags.isSExt(), Elt, dl,
- ExpactedVT);
- } else {
- // v2f16 was loaded as an i32. Now we must bitcast it back.
- Elt = DAG.getBitcast(EltVT, Elt);
- }
- InVals.push_back(Elt);
+ const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
+ unsigned I = 0;
+ for (const unsigned NumElts : VectorInfo) {
+ const EVT EltVT = VTs[I];
+ const EVT LoadVT = [&]() -> EVT {
+ // i1 is loaded/stored as i8.
+ if (EltVT == MVT::i1)
+ return MVT::i8;
+ // getLoad needs a vector type, but it can't handle
+ // vectors which contain v2f16 or v2bf16 elements. So we must load
+ // using i32 here and then bitcast back.
+ if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
+ return MVT::i32;
+ return EltVT;
+ }();
+
+ const EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
+ SDValue VecAddr = DAG.getObjectPtrOffset(
+ dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
+
+ const MaybeAlign PartAlign = [&]() -> MaybeAlign {
+ if (aggregateIsPacked)
+ return Align(1);
+ if (NumElts != 1)
+ return std::nullopt;
+ Align PartAlign = DAG.getEVTAlign(EltVT);
+ return commonAlignment(PartAlign, Offsets[I]);
+ }();
+ SDValue P =
+ DAG.getLoad(VecVT, dl, Root, VecAddr,
+ MachinePointerInfo(ADDRESS_SPACE_PARAM), PartAlign,
+ MachineMemOperand::MODereferenceable |
+ MachineMemOperand::MOInvariant);
+ if (P.getNode())
+ P.getNode()->setIROrder(Arg.getArgNo() + 1);
+ for (const unsigned J : llvm::seq(NumElts)) {
+ SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
+ DAG.getIntPtrConstant(J, dl));
+
+ // Extend or truncate the element if necessary (e.g. an i8 is loaded
+ // into an i16 register)
+ const EVT ExpactedVT = ArgIns[I + J].VT;
+ if (ExpactedVT.getFixedSizeInBits() !=
+ Elt.getValueType().getFixedSizeInBits()) {
+ assert(ExpactedVT.isScalarInteger() &&
+ Elt.getValueType().isScalarInteger() &&
+ "Non-integer argument type size mismatch");
+ Elt = DAG.getExtOrTrunc(ArgIns[I + J].Flags.isSExt(), Elt, dl,
+ ExpactedVT);
+ } else {
+ // v2f16 was loaded as an i32. Now we must bitcast it back.
+ Elt = DAG.getBitcast(EltVT, Elt);
}
-
- // Reset vector tracking state.
- VecIdx = -1;
+ InVals.push_back(Elt);
}
+ I += NumElts;
}
}
}
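For readers comparing before and after: the rewritten loop drops the PVF_FIRST/PVF_LAST flag bookkeeping in favor of a flat element index that advances by each group's size. A minimal standalone C++ sketch of just that iteration shape follows; the names (values, groupSizes) are made up for illustration and nothing in it is NVPTX- or SelectionDAG-specific.

    // Standalone sketch of the grouped-iteration pattern used above.
    #include <cstdio>
    #include <vector>

    int main() {
      // One entry per element, plus group sizes as a hypothetical
      // VectorizePTXValueVTs-style helper might report them.
      std::vector<int> values = {10, 11, 12, 13, 20, 30, 31};
      std::vector<unsigned> groupSizes = {4, 1, 2}; // 4 + 1 + 2 == values.size()

      unsigned i = 0; // flat index, advanced per group
      for (unsigned numElts : groupSizes) {
        std::printf("group of %u:", numElts);
        for (unsigned j = 0; j < numElts; ++j)
          std::printf(" %d", values[i + j]); // one operand per element
        std::printf("\n");
        i += numElts; // no FIRST/LAST flags to track or reset
      }
    }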
@@ -3527,6 +3477,11 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
const Function &F = MF.getFunction();
Type *RetTy = MF.getFunction().getReturnType();
+ if (RetTy->isVoidTy()) {
+ assert(OutVals.empty() && Outs.empty() && "Return value expected for void");
+ return DAG.getNode(NVPTXISD::RET_GLUE, dl, MVT::Other, Chain);
+ }
+
const DataLayout &DL = DAG.getDataLayout();
SmallVector<EVT, 16> VTs;
SmallVector<uint64_t, 16> Offsets;
@@ -3534,22 +3489,16 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
for (const unsigned I : llvm::seq(VTs.size()))
- if (auto PromotedVT = PromoteScalarIntegerPTX(VTs[I]))
+ if (const auto PromotedVT = PromoteScalarIntegerPTX(VTs[I]))
VTs[I] = *PromotedVT;
- auto VectorInfo = VectorizePTXValueVTs(
- VTs, Offsets,
- !RetTy->isVoidTy() ? getFunctionParamOptimizedAlign(&F, RetTy, DL)
- : Align(1));
-
// PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
// 32-bits are sign extended or zero extended, depending on whether
// they are signed or unsigned types.
const bool ExtendIntegerRetVal =
RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
- SmallVector<SDValue, 6> StoreOperands;
- for (const unsigned I : llvm::seq(VTs.size())) {
+ const auto GetRetVal = [&](unsigned I) -> SDValue {
SDValue RetVal = OutVals[I];
assert(!PromoteScalarIntegerPTX(RetVal.getValueType()) &&
"OutVal type should always be legal");
@@ -3563,61 +3512,58 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
// smallest general purpose register size supported by NVPTX.
RetVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, RetVal);
}
+ return RetVal;
+ };
- // If we have a PVF_SCALAR entry, it may not even be sufficiently aligned
- // for a scalar store. In such cases, fall back to byte stores.
- if (VectorInfo[I] == PVF_SCALAR && RetTy->isAggregateType()) {
- const EVT ElementType = ExtendIntegerRetVal ? MVT::i32 : VTs[I];
+ const auto VectorInfo = VectorizePTXValueVTs(
+ VTs, Offsets, getFunctionParamOptimizedAlign(&F, RetTy, DL));
+ unsigned I = 0;
+ for (const unsigned NumElts : VectorInfo) {
+ if (NumElts == 1 && RetTy->isAggregateType()) {
+ const EVT ElementType = VTs[I];
const Align ElementTypeAlign = DAG.getEVTAlign(ElementType);
const Align ElementAlign =
commonAlignment(DL.getABITypeAlign(RetTy), Offsets[I]);
if (ElementAlign < ElementTypeAlign) {
- assert(StoreOperands.empty() && "Orphaned operand list.");
Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[I], ElementType,
- RetVal, dl);
+ GetRetVal(I), dl);
// The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes
// into the graph, so just move on to the next element.
+ I++;
continue;
}
}
- // New load/store. Record chain and offset operands.
- if (VectorInfo[I] & PVF_FIRST) {
- assert(StoreOperands.empty() && "Orphaned operand list.");
- StoreOperands.append({Chain, DAG.getConstant(Offsets[I], dl, MVT::i32)});
- }
+ SmallVector<SDValue, 6> StoreOperands{
+ Chain, DAG.getConstant(Offsets[I], dl, MVT::i32)};
- // Record the value to return.
- StoreOperands.push_back(RetVal);
+ for (const unsigned J : llvm::seq(NumElts))
+ StoreOperands.push_back(GetRetVal(I + J));
- // That's the last element of this store op.
- if (VectorInfo[I] & PVF_LAST) {
- NVPTXISD::NodeType Op;
- unsigned NumElts = StoreOperands.size() - 2;
- switch (NumElts) {
- case 1:
- Op = NVPTXISD::StoreRetval;
- break;
- case 2:
- Op = NVPTXISD::StoreRetvalV2;
- break;
- case 4:
- Op = NVPTXISD::StoreRetvalV4;
- break;
- default:
- llvm_unreachable("Invalid vector info.");
- }
-
- // Adjust type of load/store op if we've extended the scalar
- // return value.
- EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[I];
- Chain = DAG.getMemIntrinsicNode(
- Op, dl, DAG.getVTList(MVT::Other), StoreOperands, TheStoreType,
- MachinePointerInfo(), Align(1), MachineMemOperand::MOStore);
- // Cleanup vector state.
- StoreOperands.clear();
+ NVPTXISD::NodeType Op;
+ switch (NumElts) {
+ case 1:
+ Op = NVPTXISD::StoreRetval;
+ break;
+ case 2:
+ Op = NVPTXISD::StoreRetvalV2;
+ break;
+ case 4:
+ Op = NVPTXISD::StoreRetvalV4;
+ break;
+ default:
+ llvm_unreachable("Invalid vector info.");
}
+
+ // Adjust type of load/store op if we've extended the scalar
+ // return value.
+ EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[I];
+ Chain = DAG.getMemIntrinsicNode(
+ Op, dl, DAG.getVTList(MVT::Other), StoreOperands, TheStoreType,
+ MachinePointerInfo(), Align(1), MachineMemOperand::MOStore);
+
+ I += NumElts;
}
return DAG.getNode(NVPTXISD::RET_GLUE, dl, MVT::Other, Chain);
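The same grouped loop now selects the StoreRetval/StoreRetvalV2/StoreRetvalV4 node directly from the group size, relying on the vectorizer only forming groups of 1, 2, or 4 elements. A trivial standalone stand-in for that mapping (the enum is illustrative, not the real NVPTXISD enum):

    #include <cassert>
    #include <cstdio>

    enum class StoreKind { Scalar, V2, V4 }; // stand-ins for StoreRetval{,V2,V4}

    static StoreKind selectStoreKind(unsigned numElts) {
      switch (numElts) {
      case 1: return StoreKind::Scalar;
      case 2: return StoreKind::V2;
      case 4: return StoreKind::V4;
      default: assert(false && "only groups of 1, 2, or 4 are formed");
      }
      return StoreKind::Scalar; // not reached
    }

    int main() {
      for (unsigned n : {1u, 2u, 4u})
        std::printf("%u elements -> kind %d\n", n,
                    static_cast<int>(selectStoreKind(n)));
    }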
From 70a1de76d2b8b7c69d3199d37b502301eaf65606 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Mon, 28 Apr 2025 16:16:39 +0000
Subject: [PATCH 3/3] [NVPTX][NFC] Final misc. cleanup
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 222 +++++++++-----------
llvm/lib/Target/NVPTX/NVPTXISelLowering.h | 10 +-
2 files changed, 101 insertions(+), 131 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index b8c419bee53ae..b21635f7caf04 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1144,21 +1144,24 @@ NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
std::string NVPTXTargetLowering::getPrototype(
const DataLayout &DL, Type *retTy, const ArgListTy &Args,
- const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment,
- std::optional<std::pair<unsigned, const APInt &>> VAInfo,
- const CallBase &CB, unsigned UniqueCallSite) const {
+ const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign RetAlign,
+ std::optional<std::pair<unsigned, unsigned>> VAInfo, const CallBase &CB,
+ unsigned UniqueCallSite) const {
auto PtrVT = getPointerTy(DL);
std::string Prototype;
raw_string_ostream O(Prototype);
O << "prototype_" << UniqueCallSite << " : .callprototype ";
- if (retTy->getTypeID() == Type::VoidTyID) {
+ if (retTy->isVoidTy()) {
O << "()";
} else {
O << "(";
- if ((retTy->isFloatingPointTy() || retTy->isIntegerTy()) &&
- !shouldPassAsArray(retTy)) {
+ if (shouldPassAsArray(retTy)) {
+ assert(RetAlign && "RetAlign must be set for non-void return types");
+ O << ".param .align " << RetAlign->value() << " .b8 _["
+ << DL.getTypeAllocSize(retTy) << "]";
+ } else if (retTy->isFloatingPointTy() || retTy->isIntegerTy()) {
unsigned size = 0;
if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
size = ITy->getBitWidth();
@@ -1175,9 +1178,6 @@ std::string NVPTXTargetLowering::getPrototype(
O << ".param .b" << size << " _";
} else if (isa<PointerType>(retTy)) {
O << ".param .b" << PtrVT.getSizeInBits() << " _";
- } else if (shouldPassAsArray(retTy)) {
- O << ".param .align " << (retAlignment ? retAlignment->value() : 0)
- << " .b8 _[" << DL.getTypeAllocSize(retTy) << "]";
} else {
llvm_unreachable("Unknown return type");
}
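Purely for illustration (these strings are not taken from the patch's tests): with the format strings above, an i32 return prints as ".param .b32 _", while a 16-byte aggregate return with 8-byte alignment prints as ".param .align 8 .b8 _[16]". A tiny standalone reproduction of just those two slots, using made-up sizes:

    #include <cstdio>

    int main() {
      unsigned scalarBits = 32;             // e.g. an i32 return
      std::printf(".param .b%u _\n", scalarBits);

      unsigned retAlign = 8, retBytes = 16; // e.g. a small aggregate return
      std::printf(".param .align %u .b8 _[%u]\n", retAlign, retBytes);
    }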
@@ -1187,57 +1187,52 @@ std::string NVPTXTargetLowering::getPrototype(
bool first = true;
- unsigned NumArgs = VAInfo ? VAInfo->first : Args.size();
- for (unsigned i = 0, OIdx = 0; i != NumArgs; ++i, ++OIdx) {
- Type *Ty = Args[i].Ty;
+ const unsigned NumArgs = VAInfo ? VAInfo->first : Args.size();
+ auto AllOuts = ArrayRef(Outs);
+ for (const unsigned I : llvm::seq(NumArgs)) {
+ const auto ArgOuts =
+ AllOuts.take_while([I](auto O) { return O.OrigArgIndex == I; });
+ AllOuts = AllOuts.drop_front(ArgOuts.size());
+
+ Type *Ty = Args[I].Ty;
if (!first) {
O << ", ";
}
first = false;
- if (!Outs[OIdx].Flags.isByVal()) {
+ if (ArgOuts[0].Flags.isByVal()) {
+ // Indirect calls need strict ABI alignment so we disable optimizations by
+ // not providing a function to optimize.
+ Type *ETy = Args[I].IndirectType;
+ Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
+ Align ParamByValAlign =
+ getFunctionByValParamAlign(/*F=*/nullptr, ETy, InitialAlign, DL);
+
+ O << ".param .align " << ParamByValAlign.value() << " .b8 _["
+ << ArgOuts[0].Flags.getByValSize() << "]";
+ } else {
if (shouldPassAsArray(Ty)) {
Align ParamAlign =
- getArgumentAlignment(&CB, Ty, i + AttributeList::FirstArgIndex, DL);
- O << ".param .align " << ParamAlign.value() << " .b8 ";
- O << "_";
- O << "[" << DL.getTypeAllocSize(Ty) << "]";
- // update the index for Outs
- SmallVector<EVT, 16> vtparts;
- ComputeValueVTs(*this, DL, Ty, vtparts);
- if (unsigned len = vtparts.size())
- OIdx += len - 1;
+ getArgumentAlignment(&CB, Ty, I + AttributeList::FirstArgIndex, DL);
+ O << ".param .align " << ParamAlign.value() << " .b8 _["
+ << DL.getTypeAllocSize(Ty) << "]";
continue;
}
// i8 types in IR will be i16 types in SDAG
- assert((getValueType(DL, Ty) == Outs[OIdx].VT ||
- (getValueType(DL, Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) &&
+ assert((getValueType(DL, Ty) == ArgOuts[0].VT ||
+ (getValueType(DL, Ty) == MVT::i8 && ArgOuts[0].VT == MVT::i16)) &&
"type mismatch between callee prototype and arguments");
// scalar type
unsigned sz = 0;
- if (isa<IntegerType>(Ty)) {
- sz = cast<IntegerType>(Ty)->getBitWidth();
- sz = promoteScalarArgumentSize(sz);
+ if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
+ sz = promoteScalarArgumentSize(ITy->getBitWidth());
} else if (isa<PointerType>(Ty)) {
sz = PtrVT.getSizeInBits();
} else {
sz = Ty->getPrimitiveSizeInBits();
}
- O << ".param .b" << sz << " ";
- O << "_";
- continue;
+ O << ".param .b" << sz << " _";
}
-
- // Indirect calls need strict ABI alignment so we disable optimizations by
- // not providing a function to optimize.
- Type *ETy = Args[i].IndirectType;
- Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
- Align ParamByValAlign =
- getFunctionByValParamAlign(/*F=*/nullptr, ETy, InitialAlign, DL);
-
- O << ".param .align " << ParamByValAlign.value() << " .b8 ";
- O << "_";
- O << "[" << Outs[OIdx].Flags.getByValSize() << "]";
}
if (VAInfo)
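The rewritten argument walk peels each argument's Outs entries off the front of the list by matching OrigArgIndex (take_while plus drop_front) instead of manually advancing a second index OIdx. A rough standalone equivalent of that peeling idiom using std::span (C++20); the Out struct and its fields are simplified stand-ins, not the real ISD::OutputArg:

    #include <algorithm>
    #include <cstdio>
    #include <span>
    #include <vector>

    struct Out { unsigned origArgIndex; char piece; }; // simplified stand-in

    int main() {
      // Three IR arguments; argument 1 was split into two pieces.
      std::vector<Out> outs = {{0, 'a'}, {1, 'b'}, {1, 'c'}, {2, 'd'}};
      std::span<const Out> all(outs);

      for (unsigned i = 0; i < 3; ++i) {
        // take_while: the leading run whose origArgIndex matches i.
        auto end = std::find_if_not(
            all.begin(), all.end(),
            [i](const Out &o) { return o.origArgIndex == i; });
        std::span<const Out> argOuts = all.first(end - all.begin());
        all = all.subspan(argOuts.size()); // drop_front

        std::printf("arg %u has %zu piece(s)\n", i, argOuts.size());
      }
    }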
@@ -1420,6 +1415,10 @@ static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
return MachinePointerInfo();
}
+static ISD::NodeType getExtOpcode(const ISD::ArgFlagsTy &Flags) {
+ return Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+}
+
SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SmallVectorImpl<SDValue> &InVals) const {
@@ -1483,14 +1482,14 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
"Outs and OutVals must be the same size");
// Declare the .params or .reg need to pass values
// to the function
- for (const auto [I, Arg] : llvm::enumerate(Args)) {
- const auto ArgOuts =
- AllOuts.take_while([I = I](auto O) { return O.OrigArgIndex == I; });
+ for (const auto [ArgI, Arg] : llvm::enumerate(Args)) {
+ const auto ArgOuts = AllOuts.take_while(
+ [ArgI = ArgI](auto O) { return O.OrigArgIndex == ArgI; });
const auto ArgOutVals = AllOutVals.take_front(ArgOuts.size());
AllOuts = AllOuts.drop_front(ArgOuts.size());
AllOutVals = AllOutVals.drop_front(ArgOuts.size());
- const bool IsVAArg = (I >= FirstVAArg);
+ const bool IsVAArg = (ArgI >= FirstVAArg);
const bool IsByVal = Arg.IsByVal;
SmallVector<EVT, 16> VTs;
@@ -1514,29 +1513,26 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
if (IsVAArg)
VAOffset = alignTo(VAOffset, ArgAlign);
} else {
- ArgAlign = getArgumentAlignment(CB, Arg.Ty, I + 1, DL);
+ ArgAlign = getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL);
}
const unsigned TypeSize = DL.getTypeAllocSize(ETy);
assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) &&
"type size mismatch");
- bool NeedAlign; // Does argument declaration specify alignment?
const bool PassAsArray = IsByVal || shouldPassAsArray(Arg.Ty);
if (IsVAArg) {
- if (I == FirstVAArg) {
+ if (ArgI == FirstVAArg) {
VADeclareParam = Chain =
DAG.getNode(NVPTXISD::DeclareParam, dl, {MVT::Other, MVT::Glue},
{Chain, GetI32(STI.getMaxRequiredAlignment()),
- GetI32(I), GetI32(1), InGlue});
+ GetI32(ArgI), GetI32(1), InGlue});
}
- NeedAlign = PassAsArray;
} else if (PassAsArray) {
// declare .param .align <align> .b8 .param<n>[<size>];
Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, {MVT::Other, MVT::Glue},
- {Chain, GetI32(ArgAlign.value()), GetI32(I),
+ {Chain, GetI32(ArgAlign.value()), GetI32(ArgI),
GetI32(TypeSize), InGlue});
- NeedAlign = true;
} else {
assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
// declare .param .b<size> .param<n>;
@@ -1551,8 +1547,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
Chain = DAG.getNode(
NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
- {Chain, GetI32(I), GetI32(PromotedSize), GetI32(0), InGlue});
- NeedAlign = false;
+ {Chain, GetI32(ArgI), GetI32(PromotedSize), GetI32(0), InGlue});
}
InGlue = Chain.getValue(1);
@@ -1564,7 +1559,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
const auto GetStoredValue = [&](const unsigned I, EVT EltVT,
- MaybeAlign PartAlign) {
+ const Align PartAlign) {
SDValue StVal;
if (IsByVal) {
SDValue Ptr = ArgOutVals[0];
@@ -1577,18 +1572,16 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
StVal = ArgOutVals[I];
if (auto PromotedVT = PromoteScalarIntegerPTX(StVal.getValueType())) {
- llvm::ISD::NodeType Ext =
- ArgOuts[I].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
- StVal = DAG.getNode(Ext, dl, *PromotedVT, StVal);
+ StVal = DAG.getNode(getExtOpcode(ArgOuts[I].Flags), dl, *PromotedVT,
+ StVal);
}
}
if (ExtendIntegerParam) {
assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
// zext/sext to i32
- StVal = DAG.getNode(ArgOuts[I].Flags.isSExt() ? ISD::SIGN_EXTEND
- : ISD::ZERO_EXTEND,
- dl, MVT::i32, StVal);
+ StVal =
+ DAG.getNode(getExtOpcode(ArgOuts[I].Flags), dl, MVT::i32, StVal);
} else if (EltVT.getSizeInBits() < 16) {
// Use 16-bit registers for small stores as it's the
// smallest general purpose register size supported by NVPTX.
@@ -1604,22 +1597,19 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
for (const unsigned NumElts : VectorInfo) {
const int CurOffset = Offsets[J];
EVT EltVT = VTs[J];
- MaybeAlign PartAlign;
- if (NeedAlign)
- PartAlign = commonAlignment(ArgAlign, CurOffset);
+ const Align PartAlign = commonAlignment(ArgAlign, CurOffset);
if (auto PromotedVT = PromoteScalarIntegerPTX(EltVT))
EltVT = *PromotedVT;
// If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
// scalar store. In such cases, fall back to byte stores.
- if (NumElts == 1 && !IsVAArg && PartAlign.has_value() &&
- PartAlign.value() < DAG.getEVTAlign(EltVT)) {
+ if (NumElts == 1 && !IsVAArg && PartAlign < DAG.getEVTAlign(EltVT)) {
SDValue StVal = GetStoredValue(J, EltVT, PartAlign);
Chain = LowerUnalignedStoreParam(DAG, Chain,
CurOffset + (IsByVal ? VAOffset : 0),
- EltVT, StVal, InGlue, I, dl);
+ EltVT, StVal, InGlue, ArgI, dl);
// LowerUnalignedStoreParam took care of inserting the necessary nodes
// into the SDAG, so just move on to the next element.
@@ -1634,9 +1624,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
assert((IsVAArg || VAOffset == 0) &&
"VAOffset must be 0 for non-VA args");
SmallVector<SDValue, 6> StoreOperands{
- Chain, GetI32(IsVAArg ? FirstVAArg : I),
- GetI32(IsByVal ? CurOffset + VAOffset
- : (IsVAArg ? VAOffset : CurOffset))};
+ Chain, GetI32(IsVAArg ? FirstVAArg : ArgI),
+ GetI32(VAOffset + ((IsVAArg && !IsByVal) ? 0 : CurOffset))};
// Record the values to store.
for (const unsigned K : llvm::seq(NumElts))
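The folded store-offset expression above is meant to be equivalent to the old three-way select, given the assert a few lines earlier that VAOffset is zero for non-variadic arguments. A quick standalone check of that equivalence over the byval/vararg combinations (all names are illustrative):

    #include <cassert>
    #include <cstdio>

    int main() {
      for (bool isByVal : {false, true})
        for (bool isVAArg : {false, true})
          for (unsigned curOffset : {0u, 4u, 16u})
            for (unsigned vaOffset : {0u, 8u}) {
              // Mirror the assert in the patch: VAOffset is 0 for non-VA args.
              if (!isVAArg && vaOffset != 0)
                continue;
              unsigned oldExpr = isByVal ? curOffset + vaOffset
                                         : (isVAArg ? vaOffset : curOffset);
              unsigned newExpr =
                  vaOffset + ((isVAArg && !isByVal) ? 0 : curOffset);
              assert(oldExpr == newExpr);
            }
      std::printf("old and new offset expressions agree\n");
    }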
@@ -1683,12 +1672,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
}
GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
- MaybeAlign retAlignment = std::nullopt;
+ MaybeAlign RetAlign = std::nullopt;
// Handle Result
if (!Ins.empty()) {
- SmallVector<EVT, 16> resvtparts;
- ComputeValueVTs(*this, DL, RetTy, resvtparts);
+ RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
// Declare
// .param .align N .b8 retval0[<size-in-bytes>], or
@@ -1702,9 +1690,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
DeclareRetOps);
InGlue = Chain.getValue(1);
} else {
- retAlignment = getArgumentAlignment(CB, RetTy, 0, DL);
- assert(retAlignment && "retAlignment is guaranteed to be set");
- SDValue DeclareRetOps[] = {Chain, GetI32(retAlignment->value()),
+ SDValue DeclareRetOps[] = {Chain, GetI32(RetAlign->value()),
GetI32(ResultSize / 8), GetI32(0), InGlue};
Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl,
{MVT::Other, MVT::Glue}, DeclareRetOps);
@@ -1754,10 +1740,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// The prototype is embedded in a string and put as the operand for a
// CallPrototype SDNode which will print out to the value of the string.
std::string Proto = getPrototype(
- DL, RetTy, Args, CLI.Outs, retAlignment,
+ DL, RetTy, Args, CLI.Outs, RetAlign,
HasVAArgs
- ? std::optional<std::pair<unsigned, const APInt &>>(std::make_pair(
- CLI.NumFixedArgs, VADeclareParam->getConstantOperandAPInt(1)))
+ ? std::optional<std::pair<unsigned, unsigned>>(std::make_pair(
+ CLI.NumFixedArgs, VADeclareParam.getConstantOperandVal(1)))
: std::nullopt,
*CB, UniqueCallSite);
const char *ProtoStr = nvTM->getStrPool().save(Proto).data();
@@ -1826,7 +1812,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
}
SmallVector<SDValue, 16> ProxyRegOps;
- SmallVector<std::optional<MVT>, 16> ProxyRegTruncates;
// An item of the vector is filled if the element does not need a ProxyReg
// operation on it and should be added to InVals as is. ProxyRegOps and
// ProxyRegTruncates contain empty/none items at the same index.
@@ -1843,7 +1828,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0);
assert(VTs.size() == Ins.size() && "Bad value decomposition");
- const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
+ assert(RetAlign && "RetAlign is guaranteed to be set");
// PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
// 32-bits are sign extended or zero extended, depending on whether
@@ -1851,27 +1836,22 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
const bool ExtendIntegerRetVal =
RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
- const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
+ const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, *RetAlign);
unsigned I = 0;
for (const unsigned VectorizedSize : VectorInfo) {
- bool NeedTruncate = false;
EVT TheLoadType = VTs[I];
EVT EltType = Ins[I].VT;
- const Align EltAlign = commonAlignment(RetAlign, Offsets[I]);
+ const Align EltAlign = commonAlignment(*RetAlign, Offsets[I]);
if (auto PromotedVT = PromoteScalarIntegerPTX(TheLoadType)) {
TheLoadType = *PromotedVT;
EltType = *PromotedVT;
- NeedTruncate = true;
}
if (ExtendIntegerRetVal) {
TheLoadType = MVT::i32;
EltType = MVT::i32;
- NeedTruncate = true;
} else if (TheLoadType.getSizeInBits() < 16) {
- if (TheLoadType.isInteger())
- NeedTruncate = true;
EltType = MVT::i16;
}
@@ -1882,7 +1862,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SDValue Ret = LowerUnalignedLoadRetParam(
DAG, Chain, Offsets[I], TheLoadType, InGlue, TempProxyRegOps, dl);
ProxyRegOps.push_back(SDValue());
- ProxyRegTruncates.push_back(std::optional<MVT>());
RetElts.resize(I);
RetElts.push_back(Ret);
@@ -1915,11 +1894,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
for (const unsigned J : llvm::seq(VectorizedSize)) {
ProxyRegOps.push_back(RetVal.getValue(J));
-
- if (NeedTruncate)
- ProxyRegTruncates.push_back(Ins[I + J].VT);
- else
- ProxyRegTruncates.push_back(std::nullopt);
}
Chain = RetVal.getValue(VectorizedSize);
@@ -1950,10 +1924,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
Chain = Ret.getValue(1);
InGlue = Ret.getValue(2);
- if (ProxyRegTruncates[I]) {
- Ret = DAG.getNode(ISD::TRUNCATE, dl, *ProxyRegTruncates[I], Ret);
+ const EVT ExpectedVT = Ins[I].VT;
+ if (!Ret.getValueType().bitsEq(ExpectedVT)) {
+ Ret = DAG.getNode(ISD::TRUNCATE, dl, ExpectedVT, Ret);
}
-
InVals.push_back(Ret);
}
@@ -3385,8 +3359,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
// getLoad needs a vector type, but it can't handle
// vectors which contain v2f16 or v2bf16 elements. So we must load
// using i32 here and then bitcast back.
- if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
- return MVT::i32;
+ if (EltVT.isVector())
+ return MVT::getIntegerVT(EltVT.getFixedSizeInBits());
return EltVT;
}();
@@ -3416,13 +3390,15 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
// Extend or truncate the element if necessary (e.g. an i8 is loaded
// into an i16 register)
const EVT ExpactedVT = ArgIns[I + J].VT;
- if (ExpactedVT.getFixedSizeInBits() !=
- Elt.getValueType().getFixedSizeInBits()) {
- assert(ExpactedVT.isScalarInteger() &&
- Elt.getValueType().isScalarInteger() &&
- "Non-integer argument type size mismatch");
- Elt = DAG.getExtOrTrunc(ArgIns[I + J].Flags.isSExt(), Elt, dl,
- ExpactedVT);
+ assert((Elt.getValueType().bitsEq(ExpactedVT) ||
+ (ExpactedVT.isScalarInteger() &&
+ Elt.getValueType().isScalarInteger())) &&
+ "Non-integer argument type size mismatch");
+ if (ExpactedVT.bitsGT(Elt.getValueType())) {
+ Elt = DAG.getNode(getExtOpcode(ArgIns[I + J].Flags), dl, ExpactedVT,
+ Elt);
+ } else if (ExpactedVT.bitsLT(Elt.getValueType())) {
+ Elt = DAG.getNode(ISD::TRUNCATE, dl, ExpactedVT, Elt);
} else {
// v2f16 was loaded as an i32. Now we must bitcast it back.
Elt = DAG.getBitcast(EltVT, Elt);
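With packed vector elements now loaded through a same-width integer, the fix-up after each extract is a three-way choice: widen, narrow, or reinterpret. A standalone sketch of that classification, using plain bit widths in place of EVTs (names are illustrative):

    #include <cstdio>

    // How a loaded element becomes the expected argument value.
    static const char *classify(unsigned loadedBits, unsigned expectedBits,
                                bool isSExt) {
      if (expectedBits > loadedBits)
        return isSExt ? "sign-extend" : "zero-extend";
      if (expectedBits < loadedBits)
        return "truncate";
      return "bitcast"; // same width, e.g. an i32 load standing in for v2f16
    }

    int main() {
      std::printf("i8 loaded, i16 expected:  %s\n", classify(8, 16, true));
      std::printf("i32 loaded, i16 expected: %s\n", classify(32, 16, false));
      std::printf("32 bits both ways:        %s\n", classify(32, 32, false));
    }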
@@ -3504,9 +3480,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
"OutVal type should always be legal");
if (ExtendIntegerRetVal) {
- RetVal = DAG.getNode(Outs[I].Flags.isSExt() ? ISD::SIGN_EXTEND
- : ISD::ZERO_EXTEND,
- dl, MVT::i32, RetVal);
+ RetVal = DAG.getNode(getExtOpcode(Outs[I].Flags), dl, MVT::i32, RetVal);
} else if (RetVal.getValueSizeInBits() < 16) {
// Use 16-bit registers for small load-stores as it's the
// smallest general purpose register size supported by NVPTX.
@@ -3515,24 +3489,20 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
return RetVal;
};
- const auto VectorInfo = VectorizePTXValueVTs(
- VTs, Offsets, getFunctionParamOptimizedAlign(&F, RetTy, DL));
+ const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL);
+ const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
unsigned I = 0;
for (const unsigned NumElts : VectorInfo) {
- if (NumElts == 1 && RetTy->isAggregateType()) {
- const EVT ElementType = VTs[I];
- const Align ElementTypeAlign = DAG.getEVTAlign(ElementType);
- const Align ElementAlign =
- commonAlignment(DL.getABITypeAlign(RetTy), Offsets[I]);
- if (ElementAlign < ElementTypeAlign) {
- Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[I], ElementType,
- GetRetVal(I), dl);
-
- // The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes
- // into the graph, so just move on to the next element.
- I++;
- continue;
- }
+ const Align CurrentAlign = commonAlignment(RetAlign, Offsets[I]);
+ if (NumElts == 1 && RetTy->isAggregateType() &&
+ CurrentAlign < DAG.getEVTAlign(VTs[I])) {
+ Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[I], VTs[I],
+ GetRetVal(I), dl);
+
+ // The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes
+ // into the graph, so just move on to the next element.
+ I++;
+ continue;
}
SmallVector<SDValue, 6> StoreOperands{
@@ -3561,7 +3531,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[I];
Chain = DAG.getMemIntrinsicNode(
Op, dl, DAG.getVTList(MVT::Other), StoreOperands, TheStoreType,
- MachinePointerInfo(), Align(1), MachineMemOperand::MOStore);
+ MachinePointerInfo(), CurrentAlign, MachineMemOperand::MOStore);
I += NumElts;
}
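The retval store now carries the alignment implied by the return slot's alignment and the piece's offset rather than a pessimistic Align(1). Roughly, commonAlignment(Align, Offset) behaves like the helper below; this is a simplified stand-in, not the LLVM implementation:

    #include <cstdint>
    #include <cstdio>

    // Largest power of two guaranteed for (base aligned to `align`) + `offset`:
    // the lowest set bit of (align | offset).
    static uint64_t commonAlign(uint64_t align, uint64_t offset) {
      uint64_t bits = align | offset;
      return bits & (~bits + 1);
    }

    int main() {
      // With an 8-byte-aligned return slot, pieces at these offsets get:
      for (uint64_t offset : {0u, 2u, 4u, 8u, 12u})
        std::printf("offset %llu -> align %llu\n",
                    (unsigned long long)offset,
                    (unsigned long long)commonAlign(8, offset));
    }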
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 7a8bf3bf33a94..3279a4c2e74f3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -187,11 +187,11 @@ class NVPTXTargetLowering : public TargetLowering {
SDValue LowerSTACKSAVE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSTACKRESTORE(SDValue Op, SelectionDAG &DAG) const;
- std::string
- getPrototype(const DataLayout &DL, Type *, const ArgListTy &,
- const SmallVectorImpl<ISD::OutputArg> &, MaybeAlign retAlignment,
- std::optional<std::pair<unsigned, const APInt &>> VAInfo,
- const CallBase &CB, unsigned UniqueCallSite) const;
+ std::string getPrototype(const DataLayout &DL, Type *, const ArgListTy &,
+ const SmallVectorImpl<ISD::OutputArg> &,
+ MaybeAlign RetAlign,
+ std::optional<std::pair<unsigned, unsigned>> VAInfo,
+ const CallBase &CB, unsigned UniqueCallSite) const;
SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
const SmallVectorImpl<ISD::OutputArg> &Outs,