[llvm] r211938 - [NVPTX] Clean up argument lowering code and properly handle alignment for structs and vectors

Hal Finkel hfinkel at anl.gov
Fri Jun 27 12:05:33 PDT 2014


----- Original Message -----
> From: "Justin Holewinski" <jholewinski at nvidia.com>
> To: llvm-commits at cs.uiuc.edu
> Sent: Friday, June 27, 2014 1:35:45 PM
> Subject: [llvm] r211938 - [NVPTX] Clean up argument lowering code and properly handle alignment for structs and vectors
> 
> Author: jholewinski
> Date: Fri Jun 27 13:35:44 2014
> New Revision: 211938
> 
> URL: http://llvm.org/viewvc/llvm-project?rev=211938&view=rev
> Log:
> [NVPTX] Clean up argument lowering code and properly handle alignment
> for structs and vectors
> 
> Added:
>     llvm/trunk/test/CodeGen/NVPTX/arg-lowering.ll
> Modified:
>     llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
> 
> Modified: llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
> URL:
> http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp?rev=211938&r1=211937&r2=211938&view=diff
> ==============================================================================
> --- llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp (original)
> +++ llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp Fri Jun 27 13:35:44 2014
> @@ -67,6 +67,17 @@ static bool IsPTXVectorType(MVT VT) {
>    }
>  }
>  
> +/// GCD - Greatest common divisor (Euclid); GCD(x, 0) == x. Operands are
> +/// uint64_t because callers pass uint64_t offsets, which int would narrow.
> +static uint64_t GCD(uint64_t a, uint64_t b) {
> +  while (b != 0) {   // if a < b, the first step swaps them (a % b == a)
> +    uint64_t r = a % b;
> +    a = b;
> +    b = r;
> +  }
> +  return a;
> +}

There is a GreatestCommonDivisor64 in include/llvm/Support/MathExtras.h; can you use that instead? It takes uint64_t operands, which matches the uint64_t offsets this code passes to GCD.

 -Hal

> +
>  /// ComputePTXValueVTs - For the given Type \p Ty, returns the set
>  of primitive
>  /// EVTs that compose it.  Unlike ComputeValueVTs, this will break
>  apart vectors
>  /// into their primitive components.
> @@ -518,26 +529,12 @@ NVPTXTargetLowering::getPrototype(Type *
>      } else if (isa<PointerType>(retTy)) {
>        O << ".param .b" << getPointerTy().getSizeInBits() << " _";
>      } else {
> -      if ((retTy->getTypeID() == Type::StructTyID) ||
> isa<VectorType>(retTy)) {
> -        SmallVector<EVT, 16> vtparts;
> -        ComputeValueVTs(*this, retTy, vtparts);
> -        unsigned totalsz = 0;
> -        for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
> -          unsigned elems = 1;
> -          EVT elemtype = vtparts[i];
> -          if (vtparts[i].isVector()) {
> -            elems = vtparts[i].getVectorNumElements();
> -            elemtype = vtparts[i].getVectorElementType();
> -          }
> -          // TODO: no need to loop
> -          for (unsigned j = 0, je = elems; j != je; ++j) {
> -            unsigned sz = elemtype.getSizeInBits();
> -            if (elemtype.isInteger() && (sz < 8))
> -              sz = 8;
> -            totalsz += sz / 8;
> -          }
> -        }
> -        O << ".param .align " << retAlignment << " .b8 _[" <<
> totalsz << "]";
> +      if((retTy->getTypeID() == Type::StructTyID) ||
> +         isa<VectorType>(retTy)) {
> +        O << ".param .align "
> +          << retAlignment
> +          << " .b8 _["
> +          << getDataLayout()->getTypeAllocSize(retTy) << "]";
>        } else {
>          assert(false && "Unknown return type");
>        }
> @@ -706,7 +703,8 @@ SDValue NVPTXTargetLowering::LowerCall(T
>        if (Ty->isAggregateType()) {
>          // aggregate
>          SmallVector<EVT, 16> vtparts;
> -        ComputeValueVTs(*this, Ty, vtparts);
> +        SmallVector<uint64_t, 16> Offsets;
> +        ComputePTXValueVTs(*this, Ty, vtparts, &Offsets, 0);
>  
>          unsigned align = getArgumentAlignment(Callee, CS, Ty,
>          paramCount + 1);
>          // declare .param .align <align> .b8 .param<n>[<size>];
> @@ -718,34 +716,26 @@ SDValue NVPTXTargetLowering::LowerCall(T
>          Chain = DAG.getNode(NVPTXISD::DeclareParam, dl,
>          DeclareParamVTs,
>                              DeclareParamOps);
>          InFlag = Chain.getValue(1);
> -        unsigned curOffset = 0;
>          for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
> -          unsigned elems = 1;
>            EVT elemtype = vtparts[j];
> -          if (vtparts[j].isVector()) {
> -            elems = vtparts[j].getVectorNumElements();
> -            elemtype = vtparts[j].getVectorElementType();
> -          }
> -          for (unsigned k = 0, ke = elems; k != ke; ++k) {
> -            unsigned sz = elemtype.getSizeInBits();
> -            if (elemtype.isInteger() && (sz < 8))
> -              sz = 8;
> -            SDValue StVal = OutVals[OIdx];
> -            if (elemtype.getSizeInBits() < 16) {
> -              StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16,
> StVal);
> -            }
> -            SDVTList CopyParamVTs = DAG.getVTList(MVT::Other,
> MVT::Glue);
> -            SDValue CopyParamOps[] = { Chain,
> -                                       DAG.getConstant(paramCount,
> MVT::i32),
> -                                       DAG.getConstant(curOffset,
> MVT::i32),
> -                                       StVal, InFlag };
> -            Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam,
> dl,
> -                                            CopyParamVTs,
> CopyParamOps,
> -                                            elemtype,
> MachinePointerInfo());
> -            InFlag = Chain.getValue(1);
> -            curOffset += sz / 8;
> -            ++OIdx;
> +          unsigned ArgAlign = GCD(align, Offsets[j]);
> +          if (elemtype.isInteger() && (sz < 8))
> +            sz = 8;
> +          SDValue StVal = OutVals[OIdx];
> +          if (elemtype.getSizeInBits() < 16) {
> +            StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16,
> StVal);
>            }
> +          SDVTList CopyParamVTs = DAG.getVTList(MVT::Other,
> MVT::Glue);
> +          SDValue CopyParamOps[] = { Chain,
> +                                     DAG.getConstant(paramCount,
> MVT::i32),
> +                                     DAG.getConstant(Offsets[j],
> MVT::i32),
> +                                     StVal, InFlag };
> +          Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
> +                                          CopyParamVTs,
> CopyParamOps,
> +                                          elemtype,
> MachinePointerInfo(),
> +                                          ArgAlign);
> +          InFlag = Chain.getValue(1);
> +          ++OIdx;
>          }
>          if (vtparts.size() > 0)
>            --OIdx;
> @@ -930,13 +920,15 @@ SDValue NVPTXTargetLowering::LowerCall(T
>      }
>      // struct or vector
>      SmallVector<EVT, 16> vtparts;
> +    SmallVector<uint64_t, 16> Offsets;
>      const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty);
>      assert(PTy && "Type of a byval parameter should be pointer");
> -    ComputeValueVTs(*this, PTy->getElementType(), vtparts);
> +    ComputePTXValueVTs(*this, PTy->getElementType(), vtparts,
> &Offsets, 0);
>  
>      // declare .param .align <align> .b8 .param<n>[<size>];
>      unsigned sz = Outs[OIdx].Flags.getByValSize();
>      SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
> +    unsigned ArgAlign = Outs[OIdx].Flags.getByValAlign();
>      // The ByValAlign in the Outs[OIdx].Flags is alway set at this
>      point,
>      // so we don't need to worry about natural alignment or not.
>      // See TargetLowering::LowerCallTo().
> @@ -948,38 +940,28 @@ SDValue NVPTXTargetLowering::LowerCall(T
>      Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
>                          DeclareParamOps);
>      InFlag = Chain.getValue(1);
> -    unsigned curOffset = 0;
>      for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
> -      unsigned elems = 1;
>        EVT elemtype = vtparts[j];
> -      if (vtparts[j].isVector()) {
> -        elems = vtparts[j].getVectorNumElements();
> -        elemtype = vtparts[j].getVectorElementType();
> +      int curOffset = Offsets[j];
> +      unsigned PartAlign = GCD(ArgAlign, curOffset);
> +      SDValue srcAddr =
> +          DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx],
> +                      DAG.getConstant(curOffset, getPointerTy()));
> +      SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
> +                                   MachinePointerInfo(), false,
> false, false,
> +                                   PartAlign);
> +      if (elemtype.getSizeInBits() < 16) {
> +        theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal);
>        }
> -      for (unsigned k = 0, ke = elems; k != ke; ++k) {
> -        unsigned sz = elemtype.getSizeInBits();
> -        if (elemtype.isInteger() && (sz < 8))
> -          sz = 8;
> -        SDValue srcAddr =
> -            DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx],
> -                        DAG.getConstant(curOffset, getPointerTy()));
> -        SDValue theVal = DAG.getLoad(elemtype, dl, tempChain,
> srcAddr,
> -                                     MachinePointerInfo(), false,
> false, false,
> -                                     0);
> -        if (elemtype.getSizeInBits() < 16) {
> -          theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16,
> theVal);
> -        }
> -        SDVTList CopyParamVTs = DAG.getVTList(MVT::Other,
> MVT::Glue);
> -        SDValue CopyParamOps[] = { Chain,
> DAG.getConstant(paramCount, MVT::i32),
> -                                   DAG.getConstant(curOffset,
> MVT::i32), theVal,
> -                                   InFlag };
> -        Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
> CopyParamVTs,
> -                                        CopyParamOps, elemtype,
> -                                        MachinePointerInfo());
> +      SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
> +      SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount,
> MVT::i32),
> +                                 DAG.getConstant(curOffset,
> MVT::i32), theVal,
> +                                 InFlag };
> +      Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
> CopyParamVTs,
> +                                      CopyParamOps, elemtype,
> +                                      MachinePointerInfo());
>  
> -        InFlag = Chain.getValue(1);
> -        curOffset += sz / 8;
> -      }
> +      InFlag = Chain.getValue(1);
>      }
>      ++paramCount;
>    }
> @@ -1088,7 +1070,6 @@ SDValue NVPTXTargetLowering::LowerCall(T
>  
>    // Generate loads from param memory/moves from registers for
>    result
>    if (Ins.size() > 0) {
> -    unsigned resoffset = 0;
>      if (retTy && retTy->isVectorTy()) {
>        EVT ObjectVT = getValueType(retTy);
>        unsigned NumElts = ObjectVT.getVectorNumElements();
> @@ -1097,14 +1078,15 @@ SDValue NVPTXTargetLowering::LowerCall(T
>                                                          ObjectVT) ==
>                                                          NumElts &&
>               "Vector was not scalarized");
>        unsigned sz = EltVT.getSizeInBits();
> -      bool needTruncate = sz < 16 ? true : false;
> +      bool needTruncate = sz < 8 ? true : false;
>  
>        if (NumElts == 1) {
>          // Just a simple load
>          SmallVector<EVT, 4> LoadRetVTs;
> -        if (needTruncate) {
> -          // If loading i1 result, generate
> -          //   load i16
> +        if (EltVT == MVT::i1 || EltVT == MVT::i8) {
> +          // If loading i1/i8 result, generate
> +          //   load.b8 i16
> +          //   if i1
>            //   trunc i16 to i1
>            LoadRetVTs.push_back(MVT::i16);
>          } else
> @@ -1128,9 +1110,10 @@ SDValue NVPTXTargetLowering::LowerCall(T
>        } else if (NumElts == 2) {
>          // LoadV2
>          SmallVector<EVT, 4> LoadRetVTs;
> -        if (needTruncate) {
> -          // If loading i1 result, generate
> -          //   load i16
> +        if (EltVT == MVT::i1 || EltVT == MVT::i8) {
> +          // If loading i1/i8 result, generate
> +          //   load.b8 i16
> +          //   if i1
>            //   trunc i16 to i1
>            LoadRetVTs.push_back(MVT::i16);
>            LoadRetVTs.push_back(MVT::i16);
> @@ -1173,9 +1156,10 @@ SDValue NVPTXTargetLowering::LowerCall(T
>          EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT,
>          VecSize);
>          for (unsigned i = 0; i < NumElts; i += VecSize) {
>            SmallVector<EVT, 8> LoadRetVTs;
> -          if (needTruncate) {
> -            // If loading i1 result, generate
> -            //   load i16
> +          if (EltVT == MVT::i1 || EltVT == MVT::i8) {
> +            // If loading i1/i8 result, generate
> +            //   load.b8 i16
> +            //   if i1
>              //   trunc i16 to i1
>              for (unsigned j = 0; j < VecSize; ++j)
>                LoadRetVTs.push_back(MVT::i16);
> @@ -1214,10 +1198,13 @@ SDValue NVPTXTargetLowering::LowerCall(T
>        }
>      } else {
>        SmallVector<EVT, 16> VTs;
> -      ComputePTXValueVTs(*this, retTy, VTs);
> +      SmallVector<uint64_t, 16> Offsets;
> +      ComputePTXValueVTs(*this, retTy, VTs, &Offsets, 0);
>        assert(VTs.size() == Ins.size() && "Bad value decomposition");
> +      unsigned RetAlign = getArgumentAlignment(Callee, CS, retTy,
> 0);
>        for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
>          unsigned sz = VTs[i].getSizeInBits();
> +        unsigned AlignI = GCD(RetAlign, Offsets[i]);
>          bool needTruncate = sz < 8 ? true : false;
>          if (VTs[i].isInteger() && (sz < 8))
>            sz = 8;
> @@ -1243,19 +1230,18 @@ SDValue NVPTXTargetLowering::LowerCall(T
>          SmallVector<SDValue, 4> LoadRetOps;
>          LoadRetOps.push_back(Chain);
>          LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
> -        LoadRetOps.push_back(DAG.getConstant(resoffset, MVT::i32));
> +        LoadRetOps.push_back(DAG.getConstant(Offsets[i], MVT::i32));
>          LoadRetOps.push_back(InFlag);
>          SDValue retval = DAG.getMemIntrinsicNode(
>              NVPTXISD::LoadParam, dl,
>              DAG.getVTList(LoadRetVTs), LoadRetOps,
> -            TheLoadType, MachinePointerInfo());
> +            TheLoadType, MachinePointerInfo(), AlignI);
>          Chain = retval.getValue(1);
>          InFlag = retval.getValue(2);
>          SDValue Ret0 = retval.getValue(0);
>          if (needTruncate)
>            Ret0 = DAG.getNode(ISD::TRUNCATE, dl, Ins[i].VT, Ret0);
>          InVals.push_back(Ret0);
> -        resoffset += sz / 8;
>        }
>      }
>    }
> 
> Added: llvm/trunk/test/CodeGen/NVPTX/arg-lowering.ll
> URL:
> http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/NVPTX/arg-lowering.ll?rev=211938&view=auto
> ==============================================================================
> --- llvm/trunk/test/CodeGen/NVPTX/arg-lowering.ll (added)
> +++ llvm/trunk/test/CodeGen/NVPTX/arg-lowering.ll Fri Jun 27 13:35:44 2014
> @@ -0,0 +1,17 @@
> +; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
> +
> +; A {float, float} argument is expected as an 8-byte param with 4-byte
> +; alignment; the <4 x float> return value as 16 bytes with 16-byte alignment.
> +; CHECK: .visible .func  (.param .align 16 .b8 func_retval0[16]) foo0(
> +; CHECK:          .param .align 4 .b8 foo0_param_0[8]
> +define <4 x float> @foo0({float, float} %arg0) {
> +  ret <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>
> +}
> +
> +; A {float, float, i64} argument is expected as a 16-byte param with 8-byte
> +; alignment; the <2 x float> return value as 8 bytes with 8-byte alignment.
> +; CHECK: .visible .func  (.param .align 8 .b8 func_retval0[8]) foo1(
> +; CHECK:          .param .align 8 .b8 foo1_param_0[16]
> +define <2 x float> @foo1({float, float, i64} %arg0) {
> +  ret <2 x float> <float 1.0, float 1.0>
> +}
> 
> 
> _______________________________________________
> llvm-commits mailing list
> llvm-commits at cs.uiuc.edu
> http://lists.cs.uiuc.edu/mailman/listinfo/llvm-commits
> 

-- 
Hal Finkel
Assistant Computational Scientist
Leadership Computing Facility
Argonne National Laboratory



More information about the llvm-commits mailing list