[llvm] r211938 - [NVPTX] Clean up argument lowering code and properly handle alignment for structs and vectors
Justin Holewinski
jholewinski at nvidia.com
Fri Jun 27 11:35:45 PDT 2014
Author: jholewinski
Date: Fri Jun 27 13:35:44 2014
New Revision: 211938
URL: http://llvm.org/viewvc/llvm-project?rev=211938&view=rev
Log:
[NVPTX] Clean up argument lowering code and properly handle alignment for structs and vectors
Added:
llvm/trunk/test/CodeGen/NVPTX/arg-lowering.ll
Modified:
llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
Modified: llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp?rev=211938&r1=211937&r2=211938&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp Fri Jun 27 13:35:44 2014
@@ -67,6 +67,17 @@ static bool IsPTXVectorType(MVT VT) {
}
}
+/// GCD - Greatest common divisor of \p a and \p b via Euclid's algorithm
+/// (GCD(x, 0) == x); uint64_t operands avoid narrowing uint64_t offsets.
+static uint64_t GCD(uint64_t a, uint64_t b) {
+  while (b != 0) {
+    uint64_t c = b;
+    b = a % b;
+    a = c;
+  }
+  return a;
+}
+
/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
/// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors
/// into their primitive components.
@@ -518,26 +529,12 @@ NVPTXTargetLowering::getPrototype(Type *
} else if (isa<PointerType>(retTy)) {
O << ".param .b" << getPointerTy().getSizeInBits() << " _";
} else {
- if ((retTy->getTypeID() == Type::StructTyID) || isa<VectorType>(retTy)) {
- SmallVector<EVT, 16> vtparts;
- ComputeValueVTs(*this, retTy, vtparts);
- unsigned totalsz = 0;
- for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
- unsigned elems = 1;
- EVT elemtype = vtparts[i];
- if (vtparts[i].isVector()) {
- elems = vtparts[i].getVectorNumElements();
- elemtype = vtparts[i].getVectorElementType();
- }
- // TODO: no need to loop
- for (unsigned j = 0, je = elems; j != je; ++j) {
- unsigned sz = elemtype.getSizeInBits();
- if (elemtype.isInteger() && (sz < 8))
- sz = 8;
- totalsz += sz / 8;
- }
- }
- O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]";
+ if((retTy->getTypeID() == Type::StructTyID) ||
+ isa<VectorType>(retTy)) {
+ O << ".param .align "
+ << retAlignment
+ << " .b8 _["
+ << getDataLayout()->getTypeAllocSize(retTy) << "]";
} else {
assert(false && "Unknown return type");
}
@@ -706,7 +703,8 @@ SDValue NVPTXTargetLowering::LowerCall(T
if (Ty->isAggregateType()) {
// aggregate
SmallVector<EVT, 16> vtparts;
- ComputeValueVTs(*this, Ty, vtparts);
+ SmallVector<uint64_t, 16> Offsets;
+ ComputePTXValueVTs(*this, Ty, vtparts, &Offsets, 0);
unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
// declare .param .align <align> .b8 .param<n>[<size>];
@@ -718,34 +716,26 @@ SDValue NVPTXTargetLowering::LowerCall(T
Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
DeclareParamOps);
InFlag = Chain.getValue(1);
- unsigned curOffset = 0;
for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
- unsigned elems = 1;
EVT elemtype = vtparts[j];
- if (vtparts[j].isVector()) {
- elems = vtparts[j].getVectorNumElements();
- elemtype = vtparts[j].getVectorElementType();
- }
- for (unsigned k = 0, ke = elems; k != ke; ++k) {
- unsigned sz = elemtype.getSizeInBits();
- if (elemtype.isInteger() && (sz < 8))
- sz = 8;
- SDValue StVal = OutVals[OIdx];
- if (elemtype.getSizeInBits() < 16) {
- StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
- }
- SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
- SDValue CopyParamOps[] = { Chain,
- DAG.getConstant(paramCount, MVT::i32),
- DAG.getConstant(curOffset, MVT::i32),
- StVal, InFlag };
- Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
- CopyParamVTs, CopyParamOps,
- elemtype, MachinePointerInfo());
- InFlag = Chain.getValue(1);
- curOffset += sz / 8;
- ++OIdx;
+ unsigned ArgAlign = GCD(align, Offsets[j]);
+ if (elemtype.isInteger() && (sz < 8))
+ sz = 8;
+ SDValue StVal = OutVals[OIdx];
+ if (elemtype.getSizeInBits() < 16) {
+ StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
}
+ SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
+ SDValue CopyParamOps[] = { Chain,
+ DAG.getConstant(paramCount, MVT::i32),
+ DAG.getConstant(Offsets[j], MVT::i32),
+ StVal, InFlag };
+ Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
+ CopyParamVTs, CopyParamOps,
+ elemtype, MachinePointerInfo(),
+ ArgAlign);
+ InFlag = Chain.getValue(1);
+ ++OIdx;
}
if (vtparts.size() > 0)
--OIdx;
@@ -930,13 +920,15 @@ SDValue NVPTXTargetLowering::LowerCall(T
}
// struct or vector
SmallVector<EVT, 16> vtparts;
+ SmallVector<uint64_t, 16> Offsets;
const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty);
assert(PTy && "Type of a byval parameter should be pointer");
- ComputeValueVTs(*this, PTy->getElementType(), vtparts);
+ ComputePTXValueVTs(*this, PTy->getElementType(), vtparts, &Offsets, 0);
// declare .param .align <align> .b8 .param<n>[<size>];
unsigned sz = Outs[OIdx].Flags.getByValSize();
SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
+ unsigned ArgAlign = Outs[OIdx].Flags.getByValAlign();
// The ByValAlign in the Outs[OIdx].Flags is alway set at this point,
// so we don't need to worry about natural alignment or not.
// See TargetLowering::LowerCallTo().
@@ -948,38 +940,28 @@ SDValue NVPTXTargetLowering::LowerCall(T
Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
DeclareParamOps);
InFlag = Chain.getValue(1);
- unsigned curOffset = 0;
for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
- unsigned elems = 1;
EVT elemtype = vtparts[j];
- if (vtparts[j].isVector()) {
- elems = vtparts[j].getVectorNumElements();
- elemtype = vtparts[j].getVectorElementType();
+ int curOffset = Offsets[j];
+ unsigned PartAlign = GCD(ArgAlign, curOffset);
+ SDValue srcAddr =
+ DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx],
+ DAG.getConstant(curOffset, getPointerTy()));
+ SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
+ MachinePointerInfo(), false, false, false,
+ PartAlign);
+ if (elemtype.getSizeInBits() < 16) {
+ theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal);
}
- for (unsigned k = 0, ke = elems; k != ke; ++k) {
- unsigned sz = elemtype.getSizeInBits();
- if (elemtype.isInteger() && (sz < 8))
- sz = 8;
- SDValue srcAddr =
- DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx],
- DAG.getConstant(curOffset, getPointerTy()));
- SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
- MachinePointerInfo(), false, false, false,
- 0);
- if (elemtype.getSizeInBits() < 16) {
- theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal);
- }
- SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
- SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
- DAG.getConstant(curOffset, MVT::i32), theVal,
- InFlag };
- Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
- CopyParamOps, elemtype,
- MachinePointerInfo());
+ SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
+ SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
+ DAG.getConstant(curOffset, MVT::i32), theVal,
+ InFlag };
+ Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
+ CopyParamOps, elemtype,
+ MachinePointerInfo());
- InFlag = Chain.getValue(1);
- curOffset += sz / 8;
- }
+ InFlag = Chain.getValue(1);
}
++paramCount;
}
@@ -1088,7 +1070,6 @@ SDValue NVPTXTargetLowering::LowerCall(T
// Generate loads from param memory/moves from registers for result
if (Ins.size() > 0) {
- unsigned resoffset = 0;
if (retTy && retTy->isVectorTy()) {
EVT ObjectVT = getValueType(retTy);
unsigned NumElts = ObjectVT.getVectorNumElements();
@@ -1097,14 +1078,15 @@ SDValue NVPTXTargetLowering::LowerCall(T
ObjectVT) == NumElts &&
"Vector was not scalarized");
unsigned sz = EltVT.getSizeInBits();
- bool needTruncate = sz < 16 ? true : false;
+ bool needTruncate = sz < 8 ? true : false;
if (NumElts == 1) {
// Just a simple load
SmallVector<EVT, 4> LoadRetVTs;
- if (needTruncate) {
- // If loading i1 result, generate
- // load i16
+ if (EltVT == MVT::i1 || EltVT == MVT::i8) {
+ // If loading i1/i8 result, generate
+ // load.b8 i16
+ // if i1
// trunc i16 to i1
LoadRetVTs.push_back(MVT::i16);
} else
@@ -1128,9 +1110,10 @@ SDValue NVPTXTargetLowering::LowerCall(T
} else if (NumElts == 2) {
// LoadV2
SmallVector<EVT, 4> LoadRetVTs;
- if (needTruncate) {
- // If loading i1 result, generate
- // load i16
+ if (EltVT == MVT::i1 || EltVT == MVT::i8) {
+ // If loading i1/i8 result, generate
+ // load.b8 i16
+ // if i1
// trunc i16 to i1
LoadRetVTs.push_back(MVT::i16);
LoadRetVTs.push_back(MVT::i16);
@@ -1173,9 +1156,10 @@ SDValue NVPTXTargetLowering::LowerCall(T
EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
for (unsigned i = 0; i < NumElts; i += VecSize) {
SmallVector<EVT, 8> LoadRetVTs;
- if (needTruncate) {
- // If loading i1 result, generate
- // load i16
+ if (EltVT == MVT::i1 || EltVT == MVT::i8) {
+ // If loading i1/i8 result, generate
+ // load.b8 i16
+ // if i1
// trunc i16 to i1
for (unsigned j = 0; j < VecSize; ++j)
LoadRetVTs.push_back(MVT::i16);
@@ -1214,10 +1198,13 @@ SDValue NVPTXTargetLowering::LowerCall(T
}
} else {
SmallVector<EVT, 16> VTs;
- ComputePTXValueVTs(*this, retTy, VTs);
+ SmallVector<uint64_t, 16> Offsets;
+ ComputePTXValueVTs(*this, retTy, VTs, &Offsets, 0);
assert(VTs.size() == Ins.size() && "Bad value decomposition");
+ unsigned RetAlign = getArgumentAlignment(Callee, CS, retTy, 0);
for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
unsigned sz = VTs[i].getSizeInBits();
+ unsigned AlignI = GCD(RetAlign, Offsets[i]);
bool needTruncate = sz < 8 ? true : false;
if (VTs[i].isInteger() && (sz < 8))
sz = 8;
@@ -1243,19 +1230,18 @@ SDValue NVPTXTargetLowering::LowerCall(T
SmallVector<SDValue, 4> LoadRetOps;
LoadRetOps.push_back(Chain);
LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
- LoadRetOps.push_back(DAG.getConstant(resoffset, MVT::i32));
+ LoadRetOps.push_back(DAG.getConstant(Offsets[i], MVT::i32));
LoadRetOps.push_back(InFlag);
SDValue retval = DAG.getMemIntrinsicNode(
NVPTXISD::LoadParam, dl,
DAG.getVTList(LoadRetVTs), LoadRetOps,
- TheLoadType, MachinePointerInfo());
+ TheLoadType, MachinePointerInfo(), AlignI);
Chain = retval.getValue(1);
InFlag = retval.getValue(2);
SDValue Ret0 = retval.getValue(0);
if (needTruncate)
Ret0 = DAG.getNode(ISD::TRUNCATE, dl, Ins[i].VT, Ret0);
InVals.push_back(Ret0);
- resoffset += sz / 8;
}
}
}
Added: llvm/trunk/test/CodeGen/NVPTX/arg-lowering.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/NVPTX/arg-lowering.ll?rev=211938&view=auto
==============================================================================
--- llvm/trunk/test/CodeGen/NVPTX/arg-lowering.ll (added)
+++ llvm/trunk/test/CodeGen/NVPTX/arg-lowering.ll Fri Jun 27 13:35:44 2014
@@ -0,0 +1,13 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
+; Aggregate (struct and vector) return values and parameters are expected to be
+; declared as .b8 byte arrays with an explicit .align, sized in bytes.
+
+; foo0: <4 x float> return -> 16 bytes, align 16;
+; {float, float} parameter -> 8 bytes, align 4.
+; CHECK: .visible .func (.param .align 16 .b8 func_retval0[16]) foo0(
+; CHECK: .param .align 4 .b8 foo0_param_0[8]
+define <4 x float> @foo0({float, float} %arg0) {
+ ret <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>
+}
+
+; foo1: <2 x float> return -> 8 bytes, align 8;
+; {float, float, i64} parameter -> 16 bytes, align 8.
+; CHECK: .visible .func (.param .align 8 .b8 func_retval0[8]) foo1(
+; CHECK: .param .align 8 .b8 foo1_param_0[16]
+define <2 x float> @foo1({float, float, i64} %arg0) {
+ ret <2 x float> <float 1.0, float 1.0>
+}
More information about the llvm-commits mailing list