[llvm] [NVPTX] Vectorize loads when lowering byval calls, misc. cleanup (PR #151070)
Artem Belevich via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 29 16:46:11 PDT 2025
================
@@ -1476,139 +1524,144 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
const SDValue ParamSymbol =
getCallParamSymbol(DAG, IsVAArg ? FirstVAArg : ArgI, MVT::i32);
- SmallVector<EVT, 16> VTs;
- SmallVector<uint64_t, 16> Offsets;
-
assert((!IsByVal || Arg.IndirectType) &&
"byval arg must have indirect type");
Type *ETy = (IsByVal ? Arg.IndirectType : Arg.Ty);
- ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, IsByVal ? 0 : VAOffset);
- assert(VTs.size() == Offsets.size() && "Size mismatch");
- assert((IsByVal || VTs.size() == ArgOuts.size()) && "Size mismatch");
const Align ArgAlign = [&]() {
if (IsByVal) {
// The ByValAlign in the Outs[OIdx].Flags is always set at this point,
// so we don't need to worry whether it's naturally aligned or not.
// See TargetLowering::LowerCallTo().
const Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
- const Align ByValAlign = getFunctionByValParamAlign(
- CB->getCalledFunction(), ETy, InitialAlign, DL);
- if (IsVAArg)
- VAOffset = alignTo(VAOffset, ByValAlign);
- return ByValAlign;
+ return getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
+ InitialAlign, DL);
}
return getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL);
}();
- const unsigned TypeSize = DL.getTypeAllocSize(ETy);
- assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) &&
+ const unsigned TySize = DL.getTypeAllocSize(ETy);
+ assert((!IsByVal || TySize == ArgOuts[0].Flags.getByValSize()) &&
"type size mismatch");
const SDValue ArgDeclare = [&]() {
if (IsVAArg)
return VADeclareParam;
if (IsByVal || shouldPassAsArray(Arg.Ty))
- return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TypeSize);
+ return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TySize);
assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
assert((ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint()) &&
"Only int and float types are supported as non-array arguments");
- return MakeDeclareScalarParam(ParamSymbol, TypeSize);
+ return MakeDeclareScalarParam(ParamSymbol, TySize);
}();
- // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
- // than 32-bits are sign extended or zero extended, depending on
- // whether they are signed or unsigned types. This case applies
- // only to scalar parameters and not to aggregate values.
- const bool ExtendIntegerParam =
- Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
+ if (IsByVal) {
+ assert(ArgOutVals.size() == 1 && "We must pass only one value as byval");
+ SDValue SrcPtr = ArgOutVals[0];
+ const auto PointerInfo = refinePtrAS(SrcPtr, DAG, DL, *this);
+ const Align BaseSrcAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
- const auto GetStoredValue = [&](const unsigned I, EVT EltVT,
- const MaybeAlign PartAlign) {
- if (IsByVal) {
- SDValue Ptr = ArgOutVals[0];
- auto MPI = refinePtrAS(Ptr, DAG, DL, *this);
- SDValue SrcAddr =
- DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(Offsets[I]));
-
- return DAG.getLoad(EltVT, dl, CallChain, SrcAddr, MPI, PartAlign);
+ if (IsVAArg)
+ VAOffset = alignTo(VAOffset, ArgAlign);
+
+ SmallVector<EVT, 4> ValueVTs, MemVTs;
+ SmallVector<TypeSize, 4> Offsets;
+ ComputeValueVTs(*this, DL, ETy, ValueVTs, &MemVTs, &Offsets);
+
+ unsigned J = 0;
+ const auto VI = VectorizePTXValueVTs(MemVTs, Offsets, ArgAlign, IsVAArg);
+ for (const unsigned NumElts : VI) {
+ EVT LoadVT = getVectorizedVT(MemVTs[J], NumElts, Ctx);
+ Align SrcAlign = commonAlignment(BaseSrcAlign, Offsets[J]);
+ SDValue SrcAddr = DAG.getObjectPtrOffset(dl, SrcPtr, Offsets[J]);
+ SDValue SrcLoad =
+ DAG.getLoad(LoadVT, dl, CallChain, SrcAddr, PointerInfo, SrcAlign);
+
+ TypeSize ParamOffset = Offsets[J].getWithIncrement(VAOffset);
+ Align ParamAlign = commonAlignment(ArgAlign, ParamOffset);
+ SDValue ParamAddr =
+ DAG.getObjectPtrOffset(dl, ParamSymbol, ParamOffset);
+ SDValue StoreParam =
+ DAG.getStore(ArgDeclare, dl, SrcLoad, ParamAddr,
+ MachinePointerInfo(ADDRESS_SPACE_PARAM), ParamAlign);
+ CallPrereqs.push_back(StoreParam);
+
+ J += NumElts;
}
- SDValue StVal = ArgOutVals[I];
- assert(promoteScalarIntegerPTX(StVal.getValueType()) ==
- StVal.getValueType() &&
- "OutVal type should always be legal");
-
- const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
- const EVT StoreVT =
- ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
-
- return correctParamType(StVal, StoreVT, ArgOuts[I].Flags, DAG, dl);
- };
-
- const auto VectorInfo =
- VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
-
- unsigned J = 0;
- for (const unsigned NumElts : VectorInfo) {
- const int CurOffset = Offsets[J];
- const EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
-
- if (IsVAArg && !IsByVal)
- // Align each part of the variadic argument to their type.
- VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
-
- assert((IsVAArg || VAOffset == 0) &&
- "VAOffset must be 0 for non-VA args");
+ if (IsVAArg)
+ VAOffset += TySize;
+ } else {
+ SmallVector<EVT, 16> VTs;
+ SmallVector<uint64_t, 16> Offsets;
+ ComputePTXValueVTs(*this, DL, Arg.Ty, VTs, &Offsets, VAOffset);
+ assert(VTs.size() == Offsets.size() && "Size mismatch");
+ assert(VTs.size() == ArgOuts.size() && "Size mismatch");
+
+ // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
+ // than 32-bits are sign extended or zero extended, depending on
+ // whether they are signed or unsigned types. This case applies
+ // only to scalar parameters and not to aggregate values.
+ const bool ExtendIntegerParam =
+ Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
+
+ const auto GetStoredValue = [&](const unsigned I) {
+ SDValue StVal = ArgOutVals[I];
+ assert(promoteScalarIntegerPTX(StVal.getValueType()) ==
+ StVal.getValueType() &&
+ "OutVal type should always be legal");
+
+ const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
+ const EVT StoreVT =
+ ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
+
+ return correctParamType(StVal, StoreVT, ArgOuts[I].Flags, DAG, dl);
+ };
+
+ unsigned J = 0;
+ const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
+ for (const unsigned NumElts : VI) {
+ const EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
+
+ unsigned Offset;
+ if (IsVAArg) {
+ // TODO: We may need to support vector types that can be passed
+ // as scalars in variadic arguments.
+ assert(NumElts == 1 &&
+ "Vectorization should be disabled for vaargs.");
+
+ // Align each part of the variadic argument to their type.
+ VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
+ Offset = VAOffset;
+
+ const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
+ VAOffset += DL.getTypeAllocSize(TheStoreType.getTypeForEVT(Ctx));
+ } else {
+ assert(VAOffset == 0 && "VAOffset must be 0 for non-VA args");
+ Offset = Offsets[J];
+ }
- const unsigned Offset =
- (VAOffset + ((IsVAArg && !IsByVal) ? 0 : CurOffset));
- SDValue Ptr =
- DAG.getObjectPtrOffset(dl, ParamSymbol, TypeSize::getFixed(Offset));
+ SDValue Ptr =
+ DAG.getObjectPtrOffset(dl, ParamSymbol, TypeSize::getFixed(Offset));
- const MaybeAlign CurrentAlign = ExtendIntegerParam
- ? MaybeAlign(std::nullopt)
- : commonAlignment(ArgAlign, Offset);
+ const MaybeAlign CurrentAlign = ExtendIntegerParam
+ ? MaybeAlign(std::nullopt)
+ : commonAlignment(ArgAlign, Offset);
- SDValue Val;
- if (NumElts == 1) {
- Val = GetStoredValue(J, EltVT, CurrentAlign);
- } else {
- SmallVector<SDValue, 8> StoreVals;
- for (const unsigned K : llvm::seq(NumElts)) {
- SDValue ValJ = GetStoredValue(J + K, EltVT, CurrentAlign);
- if (ValJ.getValueType().isVector())
- DAG.ExtractVectorElements(ValJ, StoreVals);
- else
- StoreVals.push_back(ValJ);
- }
+ SDValue Val = getBuildVectorizedValue(
+ [&](unsigned K) { return GetStoredValue(J + K); }, NumElts, dl,
----------------
Artem-B wrote:
Nit: I'd use `[=]` here to indicate that we're not intending to modify the captured value.
https://github.com/llvm/llvm-project/pull/151070
More information about the llvm-commits mailing list