[llvm] 5bf86d9 - [NVPTX] Remove code duplication in LowerCall
Daniil Kovalev via llvm-commits
llvm-commits at lists.llvm.org
Fri Mar 25 02:43:09 PDT 2022
Author: Daniil Kovalev
Date: 2022-03-25T12:36:20+03:00
New Revision: 5bf86d9e88fa841f5f50f4b8e3b337191691a45d
URL: https://github.com/llvm/llvm-project/commit/5bf86d9e88fa841f5f50f4b8e3b337191691a45d
DIFF: https://github.com/llvm/llvm-project/commit/5bf86d9e88fa841f5f50f4b8e3b337191691a45d.diff
LOG: [NVPTX] Remove code duplication in LowerCall
In D120129 we enhanced vectorization options for byval parameters. This patch
removes the resulting code duplication between the byval and non-byval
argument-lowering paths in LowerCall.
Differential Revision: https://reviews.llvm.org/D122381
Added:
Modified:
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 382e83dbb4cb9..11fc25722fcd6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1441,11 +1441,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
return Chain;
unsigned UniqueCallSite = GlobalUniqueCallSite.fetch_add(1);
- SDValue tempChain = Chain;
+ SDValue TempChain = Chain;
Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl);
SDValue InFlag = Chain.getValue(1);
- unsigned paramCount = 0;
+ unsigned ParamCount = 0;
// Args.size() and Outs.size() need not match.
// Outs.size() will be larger
// * if there is an aggregate argument with multiple fields (each field
@@ -1461,185 +1461,115 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
EVT VT = Outs[OIdx].VT;
Type *Ty = Args[i].Ty;
+ bool IsByVal = Outs[OIdx].Flags.isByVal();
- if (!Outs[OIdx].Flags.isByVal()) {
- SmallVector<EVT, 16> VTs;
- SmallVector<uint64_t, 16> Offsets;
- ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets);
- Align ArgAlign = getArgumentAlignment(Callee, CB, Ty, paramCount + 1, DL);
- unsigned AllocSize = DL.getTypeAllocSize(Ty);
- SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
- bool NeedAlign; // Does argument declaration specify alignment?
- if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
- // declare .param .align <align> .b8 .param<n>[<size>];
- SDValue DeclareParamOps[] = {
- Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32),
- DAG.getConstant(paramCount, dl, MVT::i32),
- DAG.getConstant(AllocSize, dl, MVT::i32), InFlag};
- Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
- DeclareParamOps);
- NeedAlign = true;
- } else {
- // declare .param .b<size> .param<n>;
- if ((VT.isInteger() || VT.isFloatingPoint()) && AllocSize < 4) {
- // PTX ABI requires integral types to be at least 32 bits in
- // size. FP16 is loaded/stored using i16, so it's handled
- // here as well.
- AllocSize = 4;
- }
- SDValue DeclareScalarParamOps[] = {
- Chain, DAG.getConstant(paramCount, dl, MVT::i32),
- DAG.getConstant(AllocSize * 8, dl, MVT::i32),
- DAG.getConstant(0, dl, MVT::i32), InFlag};
- Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
- DeclareScalarParamOps);
- NeedAlign = false;
- }
- InFlag = Chain.getValue(1);
-
- // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
- // than 32-bits are sign extended or zero extended, depending on
- // whether they are signed or unsigned types. This case applies
- // only to scalar parameters and not to aggregate values.
- bool ExtendIntegerParam =
- Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty) < 32;
-
- auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
- SmallVector<SDValue, 6> StoreOperands;
- for (unsigned j = 0, je = VTs.size(); j != je; ++j) {
- // New store.
- if (VectorInfo[j] & PVF_FIRST) {
- assert(StoreOperands.empty() && "Unfinished preceding store.");
- StoreOperands.push_back(Chain);
- StoreOperands.push_back(DAG.getConstant(paramCount, dl, MVT::i32));
- StoreOperands.push_back(DAG.getConstant(Offsets[j], dl, MVT::i32));
- }
-
- EVT EltVT = VTs[j];
- SDValue StVal = OutVals[OIdx];
- if (ExtendIntegerParam) {
- assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
- // zext/sext to i32
- StVal = DAG.getNode(Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
- : ISD::ZERO_EXTEND,
- dl, MVT::i32, StVal);
- } else if (EltVT.getSizeInBits() < 16) {
- // Use 16-bit registers for small stores as it's the
- // smallest general purpose register size supported by NVPTX.
- StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
- }
-
- // Record the value to store.
- StoreOperands.push_back(StVal);
-
- if (VectorInfo[j] & PVF_LAST) {
- unsigned NumElts = StoreOperands.size() - 3;
- NVPTXISD::NodeType Op;
- switch (NumElts) {
- case 1:
- Op = NVPTXISD::StoreParam;
- break;
- case 2:
- Op = NVPTXISD::StoreParamV2;
- break;
- case 4:
- Op = NVPTXISD::StoreParamV4;
- break;
- default:
- llvm_unreachable("Invalid vector info.");
- }
-
- StoreOperands.push_back(InFlag);
-
- // Adjust type of the store op if we've extended the scalar
- // return value.
- EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : VTs[j];
- MaybeAlign EltAlign;
- if (NeedAlign)
- EltAlign = commonAlignment(ArgAlign, Offsets[j]);
-
- Chain = DAG.getMemIntrinsicNode(
- Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands,
- TheStoreType, MachinePointerInfo(), EltAlign,
- MachineMemOperand::MOStore);
- InFlag = Chain.getValue(1);
-
- // Cleanup.
- StoreOperands.clear();
- }
- ++OIdx;
- }
- assert(StoreOperands.empty() && "Unfinished parameter store.");
- if (VTs.size() > 0)
- --OIdx;
- ++paramCount;
- continue;
- }
-
- // ByVal arguments
- // TODO: remove code duplication when handling byval and non-byval cases.
SmallVector<EVT, 16> VTs;
SmallVector<uint64_t, 16> Offsets;
- Type *ETy = Args[i].IndirectType;
- assert(ETy && "byval arg must have indirect type");
- ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, 0);
- // declare .param .align <align> .b8 .param<n>[<size>];
- unsigned sz = Outs[OIdx].Flags.getByValSize();
- SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
+ assert((!IsByVal || Args[i].IndirectType) &&
+ "byval arg must have indirect type");
+ Type *ETy = (IsByVal ? Args[i].IndirectType : Ty);
+ ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets);
+
+ Align ArgAlign;
+ if (IsByVal) {
+ // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
+ // so we don't need to worry whether it's naturally aligned or not.
+ // See TargetLowering::LowerCallTo().
+ ArgAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
+
+ // Try to increase alignment to enhance vectorization options.
+ ArgAlign = std::max(ArgAlign, getFunctionParamOptimizedAlign(
+ CB->getCalledFunction(), ETy, DL));
+
+ // Enforce minumum alignment of 4 to work around ptxas miscompile
+ // for sm_50+. See corresponding alignment adjustment in
+ // emitFunctionParamList() for details.
+ ArgAlign = std::max(ArgAlign, Align(4));
+ } else {
+ ArgAlign = getArgumentAlignment(Callee, CB, Ty, ParamCount + 1, DL);
+ }
- // The ByValAlign in the Outs[OIdx].Flags is alway set at this point,
- // so we don't need to worry about natural alignment or not.
- // See TargetLowering::LowerCallTo().
- Align ArgAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
+ unsigned TypeSize =
+ (IsByVal ? Outs[OIdx].Flags.getByValSize() : DL.getTypeAllocSize(Ty));
+ SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
- // Try to increase alignment to enhance vectorization options.
- const Function *F = CB->getCalledFunction();
- Align AlignCandidate = getFunctionParamOptimizedAlign(F, ETy, DL);
- ArgAlign = std::max(ArgAlign, AlignCandidate);
-
- // Enforce minumum alignment of 4 to work around ptxas miscompile
- // for sm_50+. See corresponding alignment adjustment in
- // emitFunctionParamList() for details.
- if (ArgAlign < Align(4))
- ArgAlign = Align(4);
- SDValue DeclareParamOps[] = {
- Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32),
- DAG.getConstant(paramCount, dl, MVT::i32),
- DAG.getConstant(sz, dl, MVT::i32), InFlag};
- Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
- DeclareParamOps);
+ bool NeedAlign; // Does argument declaration specify alignment?
+ if (IsByVal ||
+ (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128))) {
+ // declare .param .align <align> .b8 .param<n>[<size>];
+ SDValue DeclareParamOps[] = {
+ Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32),
+ DAG.getConstant(ParamCount, dl, MVT::i32),
+ DAG.getConstant(TypeSize, dl, MVT::i32), InFlag};
+ Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
+ DeclareParamOps);
+ NeedAlign = true;
+ } else {
+ // declare .param .b<size> .param<n>;
+ if ((VT.isInteger() || VT.isFloatingPoint()) && TypeSize < 4) {
+ // PTX ABI requires integral types to be at least 32 bits in
+ // size. FP16 is loaded/stored using i16, so it's handled
+ // here as well.
+ TypeSize = 4;
+ }
+ SDValue DeclareScalarParamOps[] = {
+ Chain, DAG.getConstant(ParamCount, dl, MVT::i32),
+ DAG.getConstant(TypeSize * 8, dl, MVT::i32),
+ DAG.getConstant(0, dl, MVT::i32), InFlag};
+ Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
+ DeclareScalarParamOps);
+ NeedAlign = false;
+ }
InFlag = Chain.getValue(1);
+ // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
+ // than 32-bits are sign extended or zero extended, depending on
+ // whether they are signed or unsigned types. This case applies
+ // only to scalar parameters and not to aggregate values.
+ bool ExtendIntegerParam =
+ Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty) < 32;
+
auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
SmallVector<SDValue, 6> StoreOperands;
for (unsigned j = 0, je = VTs.size(); j != je; ++j) {
- EVT elemtype = VTs[j];
- int curOffset = Offsets[j];
- Align PartAlign = commonAlignment(ArgAlign, curOffset);
+ EVT EltVT = VTs[j];
+ int CurOffset = Offsets[j];
+ MaybeAlign PartAlign;
+ if (NeedAlign)
+ PartAlign = commonAlignment(ArgAlign, CurOffset);
// New store.
if (VectorInfo[j] & PVF_FIRST) {
assert(StoreOperands.empty() && "Unfinished preceding store.");
StoreOperands.push_back(Chain);
- StoreOperands.push_back(DAG.getConstant(paramCount, dl, MVT::i32));
- StoreOperands.push_back(DAG.getConstant(curOffset, dl, MVT::i32));
+ StoreOperands.push_back(DAG.getConstant(ParamCount, dl, MVT::i32));
+ StoreOperands.push_back(DAG.getConstant(CurOffset, dl, MVT::i32));
}
- auto PtrVT = getPointerTy(DL);
- SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, OutVals[OIdx],
- DAG.getConstant(curOffset, dl, PtrVT));
- SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
- MachinePointerInfo(), PartAlign);
+ SDValue StVal = OutVals[OIdx];
+ if (IsByVal) {
+ auto PtrVT = getPointerTy(DL);
+ SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
+ DAG.getConstant(CurOffset, dl, PtrVT));
+ StVal = DAG.getLoad(EltVT, dl, TempChain, srcAddr, MachinePointerInfo(),
+ PartAlign);
+ } else if (ExtendIntegerParam) {
+ assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
+ // zext/sext to i32
+ StVal = DAG.getNode(Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
+ : ISD::ZERO_EXTEND,
+ dl, MVT::i32, StVal);
+ }
- if (elemtype.getSizeInBits() < 16) {
+ if (!ExtendIntegerParam && EltVT.getSizeInBits() < 16) {
// Use 16-bit registers for small stores as it's the
// smallest general purpose register size supported by NVPTX.
- theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal);
+ StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
}
// Record the value to store.
- StoreOperands.push_back(theVal);
+ StoreOperands.push_back(StVal);
if (VectorInfo[j] & PVF_LAST) {
unsigned NumElts = StoreOperands.size() - 3;
@@ -1660,18 +1590,26 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
StoreOperands.push_back(InFlag);
+ // Adjust type of the store op if we've extended the scalar
+ // return value.
+ EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
+
Chain = DAG.getMemIntrinsicNode(
Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands,
- elemtype, MachinePointerInfo(), PartAlign,
+ TheStoreType, MachinePointerInfo(), PartAlign,
MachineMemOperand::MOStore);
InFlag = Chain.getValue(1);
// Cleanup.
StoreOperands.clear();
}
+ if (!IsByVal)
+ ++OIdx;
}
assert(StoreOperands.empty() && "Unfinished parameter store.");
- ++paramCount;
+ if (!IsByVal && VTs.size() > 0)
+ --OIdx;
+ ++ParamCount;
}
GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
@@ -1778,7 +1716,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
CallArgBeginOps);
InFlag = Chain.getValue(1);
- for (unsigned i = 0, e = paramCount; i != e; ++i) {
+ for (unsigned i = 0, e = ParamCount; i != e; ++i) {
unsigned opcode;
if (i == (e - 1))
opcode = NVPTXISD::LastCallArg;
More information about the llvm-commits mailing list