[llvm] 619b7ce - [NVPTX] Backend support for variadic functions

Andrew Savonichev via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 13 08:08:09 PST 2022


Author: Pavel Kopyl
Date: 2022-12-13T19:07:43+03:00
New Revision: 619b7cecf35507c6a9d658719732146644a05baf

URL: https://github.com/llvm/llvm-project/commit/619b7cecf35507c6a9d658719732146644a05baf
DIFF: https://github.com/llvm/llvm-project/commit/619b7cecf35507c6a9d658719732146644a05baf.diff

LOG: [NVPTX] Backend support for variadic functions

This patch adds lowering for function calls with a variadic number of
arguments and enables support for the following
instructions/intrinsics:

  - va_arg
  - va_start
  - va_end
  - va_copy

Note that this patch does not intend to include clang's support for
variadic functions for CUDA.

According to the docs:

  PTX version 6.0 supports passing unsized array parameter to a
  function which can be used to implement variadic functions. [0]

  The last parameter in the parameter list may be a .param array of
  type .b8 with no size specified. It is used to pass an arbitrary
  number of parameters to the function packed into a single array
  object.

  When calling a function with such an unsized last argument, the last
  argument may be omitted from the call instruction if no parameter is
  passed through it.  Accesses to this array parameter must be within
  the bounds of the array.  The result of an access is undefined if no
  array was passed, or if the access was outside the bounds of the
  actual array being passed. [1]

Note that aggregates passed by value as variadic arguments are not
currently supported.

[0] https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#variadic-functions
[1] https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#kernel-and-function-directives-func

Differential Revision: https://reviews.llvm.org/D138531

Added: 
    llvm/test/CodeGen/NVPTX/vaargs.ll

Modified: 
    llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
    llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/lib/Target/NVPTX/NVPTXISelLowering.h
    llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
    llvm/lib/Target/NVPTX/NVPTXSubtarget.h
    llvm/test/CodeGen/NVPTX/symbol-naming.ll
    llvm/test/DebugInfo/NVPTX/dbg-value-const-byref.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 1326b60beeb97..da1662e29c235 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1466,7 +1466,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
   bool isABI = (STI.getSmVersion() >= 20);
   bool hasImageHandles = STI.hasImageHandles();
 
-  if (F->arg_empty()) {
+  if (F->arg_empty() && !F->isVarArg()) {
     O << "()\n";
     return;
   }
@@ -1670,6 +1670,15 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
     }
   }
 
+  if (F->isVarArg()) {
+    if (!first)
+      O << ",\n";
+    O << "\t.param .align " << STI.getMaxRequiredAlignment();
+    O << " .b8 ";
+    getSymbol(F)->print(O, MAI);
+    O << "_vararg[]";
+  }
+
   O << "\n)\n";
 }
 

diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 0709e8977fc8a..1ae52d4fecdbf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -319,12 +319,15 @@ enum ParamVectorizationFlags {
 static SmallVector<ParamVectorizationFlags, 16>
 VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
                      const SmallVectorImpl<uint64_t> &Offsets,
-                     Align ParamAlignment) {
+                     Align ParamAlignment, bool IsVAArg = false) {
   // Set vector size to match ValueVTs and mark all elements as
   // scalars by default.
   SmallVector<ParamVectorizationFlags, 16> VectorInfo;
   VectorInfo.assign(ValueVTs.size(), PVF_SCALAR);
 
+  if (IsVAArg)
+    return VectorInfo;
+
   // Check what we can vectorize using 128/64/32-bit accesses.
   for (int I = 0, E = ValueVTs.size(); I != E; ++I) {
     // Skip elements we've already processed.
@@ -514,6 +517,12 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
     }
   }
 
+  // Support varargs.
+  setOperationAction(ISD::VASTART, MVT::Other, Custom);
+  setOperationAction(ISD::VAARG, MVT::Other, Custom);
+  setOperationAction(ISD::VACOPY, MVT::Other, Expand);
+  setOperationAction(ISD::VAEND, MVT::Other, Expand);
+
   // Custom handling for i8 intrinsics
   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);
 
@@ -1309,7 +1318,8 @@ NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
 std::string NVPTXTargetLowering::getPrototype(
     const DataLayout &DL, Type *retTy, const ArgListTy &Args,
     const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment,
-    const CallBase &CB, unsigned UniqueCallSite) const {
+    Optional<std::pair<unsigned, const APInt &>> VAInfo, const CallBase &CB,
+    unsigned UniqueCallSite) const {
   auto PtrVT = getPointerTy(DL);
 
   bool isABI = (STI.getSmVersion() >= 20);
@@ -1317,7 +1327,8 @@ std::string NVPTXTargetLowering::getPrototype(
   if (!isABI)
     return "";
 
-  std::stringstream O;
+  std::string Prototype;
+  raw_string_ostream O(Prototype);
   O << "prototype_" << UniqueCallSite << " : .callprototype ";
 
   if (retTy->getTypeID() == Type::VoidTyID) {
@@ -1355,7 +1366,8 @@ std::string NVPTXTargetLowering::getPrototype(
   bool first = true;
 
   const Function *F = CB.getFunction();
-  for (unsigned i = 0, e = Args.size(), OIdx = 0; i != e; ++i, ++OIdx) {
+  unsigned NumArgs = VAInfo ? VAInfo->first : Args.size();
+  for (unsigned i = 0, OIdx = 0; i != NumArgs; ++i, ++OIdx) {
     Type *Ty = Args[i].Ty;
     if (!first) {
       O << ", ";
@@ -1414,8 +1426,13 @@ std::string NVPTXTargetLowering::getPrototype(
     O << "_";
     O << "[" << Outs[OIdx].Flags.getByValSize() << "]";
   }
+
+  if (VAInfo)
+    O << (first ? "" : ",") << " .param .align " << VAInfo->second
+      << " .b8 _[]\n";
   O << ");";
-  return O.str();
+
+  return Prototype;
 }
 
 Align NVPTXTargetLowering::getArgumentAlignment(SDValue Callee,
@@ -1459,6 +1476,12 @@ Align NVPTXTargetLowering::getArgumentAlignment(SDValue Callee,
 
 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                                        SmallVectorImpl<SDValue> &InVals) const {
+
+  if (CLI.IsVarArg && (STI.getPTXVersion() < 60 || STI.getSmVersion() < 30))
+    report_fatal_error(
+        "Support for variadic functions (unsized array parameter) introduced "
+        "in PTX ISA version 6.0 and requires target sm_30.");
+
   SelectionDAG &DAG = CLI.DAG;
   SDLoc dl = CLI.DL;
   SmallVectorImpl<ISD::OutputArg> &Outs = CLI.Outs;
@@ -1477,6 +1500,26 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   if (!isABI)
     return Chain;
 
+  // Variadic arguments.
+  //
+  // Normally, for each argument, we declare a param scalar or a param
+  // byte array in the .param space, and store the argument value to that
+  // param scalar or array starting at offset 0.
+  //
+  // In the case of the first variadic argument, we declare a vararg byte array
+  // with size 0. The exact size of this array isn't known at this point, so
+  // it'll be patched later. All the variadic arguments will be stored to this
+  // array at a certain offset (which gets tracked by 'VAOffset'). The offset is
+  // initially set to 0, so it can be used for non-variadic arguments (which use
+  // 0 offset) to simplify the code.
+  //
+  // After all variadic arguments are processed, 'VAOffset' holds the size
+  // of the vararg byte array.
+
+  SDValue VADeclareParam;                 // vararg byte array
+  unsigned FirstVAArg = CLI.NumFixedArgs; // position of the first variadic
+  unsigned VAOffset = 0;                  // current offset in the param array
+
   unsigned UniqueCallSite = GlobalUniqueCallSite.fetch_add(1);
   SDValue TempChain = Chain;
   Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl);
@@ -1498,6 +1541,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
     EVT VT = Outs[OIdx].VT;
     Type *Ty = Args[i].Ty;
+    bool IsVAArg = (i >= CLI.NumFixedArgs);
     bool IsByVal = Outs[OIdx].Flags.isByVal();
 
     SmallVector<EVT, 16> VTs;
@@ -1506,7 +1550,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     assert((!IsByVal || Args[i].IndirectType) &&
            "byval arg must have indirect type");
     Type *ETy = (IsByVal ? Args[i].IndirectType : Ty);
-    ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets);
+    ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, IsByVal ? 0 : VAOffset);
 
     Align ArgAlign;
     if (IsByVal) {
@@ -1516,13 +1560,16 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       ArgAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
 
       // Try to increase alignment to enhance vectorization options.
-      ArgAlign = std::max(ArgAlign, getFunctionParamOptimizedAlign(
-                                        getMaybeBitcastedCallee(CB), ETy, DL));
+      if (const Function *DirectCallee = CB->getCalledFunction())
+        ArgAlign = std::max(
+            ArgAlign, getFunctionParamOptimizedAlign(DirectCallee, ETy, DL));
 
       // Enforce minumum alignment of 4 to work around ptxas miscompile
       // for sm_50+. See corresponding alignment adjustment in
       // emitFunctionParamList() for details.
       ArgAlign = std::max(ArgAlign, Align(4));
+      if (IsVAArg)
+        VAOffset = alignTo(VAOffset, ArgAlign);
     } else {
       ArgAlign = getArgumentAlignment(Callee, CB, Ty, ParamCount + 1, DL);
     }
@@ -1532,8 +1579,19 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
 
     bool NeedAlign; // Does argument declaration specify alignment?
-    if (IsByVal ||
-        (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128))) {
+    if (IsVAArg) {
+      if (ParamCount == FirstVAArg) {
+        SDValue DeclareParamOps[] = {
+            Chain, DAG.getConstant(STI.getMaxRequiredAlignment(), dl, MVT::i32),
+            DAG.getConstant(ParamCount, dl, MVT::i32),
+            DAG.getConstant(1, dl, MVT::i32), InFlag};
+        VADeclareParam = Chain = DAG.getNode(NVPTXISD::DeclareParam, dl,
+                                             DeclareParamVTs, DeclareParamOps);
+      }
+      NeedAlign = IsByVal || Ty->isAggregateType() || Ty->isVectorTy() ||
+                  Ty->isIntegerTy(128);
+    } else if (IsByVal || Ty->isAggregateType() || Ty->isVectorTy() ||
+               Ty->isIntegerTy(128)) {
       // declare .param .align <align> .b8 .param<n>[<size>];
       SDValue DeclareParamOps[] = {
           Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32),
@@ -1567,7 +1625,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     bool ExtendIntegerParam =
         Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty) < 32;
 
-    auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
+    auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
     SmallVector<SDValue, 6> StoreOperands;
     for (unsigned j = 0, je = VTs.size(); j != je; ++j) {
       EVT EltVT = VTs[j];
@@ -1580,8 +1638,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       if (VectorInfo[j] & PVF_FIRST) {
         assert(StoreOperands.empty() && "Unfinished preceding store.");
         StoreOperands.push_back(Chain);
-        StoreOperands.push_back(DAG.getConstant(ParamCount, dl, MVT::i32));
-        StoreOperands.push_back(DAG.getConstant(CurOffset, dl, MVT::i32));
+        StoreOperands.push_back(
+            DAG.getConstant(IsVAArg ? FirstVAArg : ParamCount, dl, MVT::i32));
+        StoreOperands.push_back(DAG.getConstant(
+            IsByVal ? CurOffset + VAOffset : (IsVAArg ? VAOffset : CurOffset),
+            dl, MVT::i32));
       }
 
       SDValue StVal = OutVals[OIdx];
@@ -1650,6 +1711,15 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
 
         // Cleanup.
         StoreOperands.clear();
+
+        // TODO: We may need to support vector types that can be passed
+        // as scalars in variadic arguments.
+        if (!IsByVal && IsVAArg) {
+          assert(NumElts == 1 &&
+                 "Vectorization is expected to be disabled for variadics.");
+          VAOffset += DL.getTypeAllocSize(
+              TheStoreType.getTypeForEVT(*DAG.getContext()));
+        }
       }
       if (!IsByVal)
         ++OIdx;
@@ -1658,6 +1728,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     if (!IsByVal && VTs.size() > 0)
       --OIdx;
     ++ParamCount;
+    if (IsByVal && IsVAArg)
+      VAOffset += TypeSize;
   }
 
   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
@@ -1700,6 +1772,18 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     }
   }
 
+  bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
+  // Set the size of the vararg param byte array if the callee is a variadic
+  // function and the variadic part is not empty.
+  if (HasVAArgs) {
+    SDValue DeclareParamOps[] = {
+        VADeclareParam.getOperand(0), VADeclareParam.getOperand(1),
+        VADeclareParam.getOperand(2), DAG.getConstant(VAOffset, dl, MVT::i32),
+        VADeclareParam.getOperand(4)};
+    DAG.MorphNodeTo(VADeclareParam.getNode(), VADeclareParam.getOpcode(),
+                    VADeclareParam->getVTList(), DeclareParamOps);
+  }
+
   // Both indirect calls and libcalls have nullptr Func. In order to distinguish
   // between them we must rely on the call site value which is valid for
   // indirect calls but is always null for libcalls.
@@ -1726,8 +1810,14 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     // The prototype is embedded in a string and put as the operand for a
     // CallPrototype SDNode which will print out to the value of the string.
     SDVTList ProtoVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-    std::string Proto =
-        getPrototype(DL, RetTy, Args, Outs, retAlignment, *CB, UniqueCallSite);
+    std::string Proto = getPrototype(
+        DL, RetTy, Args, Outs, retAlignment,
+        HasVAArgs ? Optional<std::pair<unsigned, const APInt &>>(std::make_pair(
+                        CLI.NumFixedArgs,
+                        cast<ConstantSDNode>(VADeclareParam->getOperand(1))
+                            ->getAPIntValue()))
+                  : std::nullopt,
+        *CB, UniqueCallSite);
     const char *ProtoStr =
       nvTM->getManagedStrPool()->getManagedString(Proto.c_str())->c_str();
     SDValue ProtoOps[] = {
@@ -1762,7 +1852,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                       CallArgBeginOps);
   InFlag = Chain.getValue(1);
 
-  for (unsigned i = 0, e = ParamCount; i != e; ++i) {
+  for (unsigned i = 0, e = std::min(CLI.NumFixedArgs + 1, ParamCount); i != e;
+       ++i) {
     unsigned opcode;
     if (i == (e - 1))
       opcode = NVPTXISD::LastCallArg;
@@ -2235,11 +2326,73 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerSelect(Op, DAG);
   case ISD::FROUND:
     return LowerFROUND(Op, DAG);
+  case ISD::VAARG:
+    return LowerVAARG(Op, DAG);
+  case ISD::VASTART:
+    return LowerVASTART(Op, DAG);
   default:
     llvm_unreachable("Custom lowering not defined for operation");
   }
 }
 
+// This function is almost a copy of SelectionDAG::expandVAArg().
+// The only difference is that this one produces loads from local address space.
+SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
+  const TargetLowering *TLI = STI.getTargetLowering();
+  SDLoc DL(Op);
+
+  SDNode *Node = Op.getNode();
+  const Value *V = cast<SrcValueSDNode>(Node->getOperand(2))->getValue();
+  EVT VT = Node->getValueType(0);
+  auto *Ty = VT.getTypeForEVT(*DAG.getContext());
+  SDValue Tmp1 = Node->getOperand(0);
+  SDValue Tmp2 = Node->getOperand(1);
+  const MaybeAlign MA(Node->getConstantOperandVal(3));
+
+  SDValue VAListLoad = DAG.getLoad(TLI->getPointerTy(DAG.getDataLayout()), DL,
+                                   Tmp1, Tmp2, MachinePointerInfo(V));
+  SDValue VAList = VAListLoad;
+
+  if (MA && *MA > TLI->getMinStackArgumentAlignment()) {
+    VAList = DAG.getNode(
+        ISD::ADD, DL, VAList.getValueType(), VAList,
+        DAG.getConstant(MA->value() - 1, DL, VAList.getValueType()));
+
+    VAList = DAG.getNode(
+        ISD::AND, DL, VAList.getValueType(), VAList,
+        DAG.getConstant(-(int64_t)MA->value(), DL, VAList.getValueType()));
+  }
+
+  // Increment the pointer, VAList, to the next vaarg
+  Tmp1 = DAG.getNode(ISD::ADD, DL, VAList.getValueType(), VAList,
+                     DAG.getConstant(DAG.getDataLayout().getTypeAllocSize(Ty),
+                                     DL, VAList.getValueType()));
+
+  // Store the incremented VAList to the legalized pointer
+  Tmp1 = DAG.getStore(VAListLoad.getValue(1), DL, Tmp1, Tmp2,
+                      MachinePointerInfo(V));
+
+  const Value *SrcV =
+      Constant::getNullValue(PointerType::get(Ty, ADDRESS_SPACE_LOCAL));
+
+  // Load the actual argument out of the pointer VAList
+  return DAG.getLoad(VT, DL, Tmp1, VAList, MachinePointerInfo(SrcV));
+}
+
+SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
+  const TargetLowering *TLI = STI.getTargetLowering();
+  SDLoc DL(Op);
+  EVT PtrVT = TLI->getPointerTy(DAG.getDataLayout());
+
+  // Store the address of unsized array <function>_vararg[] in the ap object.
+  SDValue Arg = getParamSymbol(DAG, /* vararg */ -1, PtrVT);
+  SDValue VAReg = DAG.getNode(NVPTXISD::Wrapper, DL, PtrVT, Arg);
+
+  const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
+  return DAG.getStore(Op.getOperand(0), DL, VAReg, Op.getOperand(1),
+                      MachinePointerInfo(SV));
+}
+
 SDValue NVPTXTargetLowering::LowerSelect(SDValue Op, SelectionDAG &DAG) const {
   SDValue Op0 = Op->getOperand(0);
   SDValue Op1 = Op->getOperand(1);
@@ -2461,13 +2614,21 @@ SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
   return Result;
 }
 
-SDValue
-NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx, EVT v) const {
+// This creates a target external symbol for a function parameter.
+// The symbol name is composed of the function name and the parameter index.
+// A negative index denotes the special parameter (unsized array) used for
+// passing variadic arguments.
+SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx,
+                                            EVT v) const {
   std::string ParamSym;
   raw_string_ostream ParamStr(ParamSym);
 
-  ParamStr << DAG.getMachineFunction().getName() << "_param_" << idx;
-  ParamStr.flush();
+  ParamStr << DAG.getMachineFunction().getName();
+
+  if (idx < 0)
+    ParamStr << "_vararg";
+  else
+    ParamStr << "_param_" << idx;
 
   std::string *SavedStr =
     nvTM->getManagedStrPool()->getManagedString(ParamSym.c_str());
@@ -2652,14 +2813,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
     InVals.push_back(p);
   }
 
-  // Clang will check explicit VarArg and issue error if any. However, Clang
-  // will let code with
-  // implicit var arg like f() pass. See bug 617733.
-  // We treat this case as if the arg list is empty.
-  // if (F.isVarArg()) {
-  // assert(0 && "VarArg not supported yet!");
-  //}
-
   if (!OutChains.empty())
     DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains));
 

diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index ae66816548f9f..3c088428718be 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -501,8 +501,9 @@ class NVPTXTargetLowering : public TargetLowering {
 
   std::string getPrototype(const DataLayout &DL, Type *, const ArgListTy &,
                            const SmallVectorImpl<ISD::OutputArg> &,
-                           MaybeAlign retAlignment, const CallBase &CB,
-                           unsigned UniqueCallSite) const;
+                           MaybeAlign retAlignment,
+                           Optional<std::pair<unsigned, const APInt &>> VAInfo,
+                           const CallBase &CB, unsigned UniqueCallSite) const;
 
   SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
                       const SmallVectorImpl<ISD::OutputArg> &Outs,
@@ -596,6 +597,9 @@ class NVPTXTargetLowering : public TargetLowering {
 
   SDValue LowerSelect(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const;
+
   void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
                           SelectionDAG &DAG) const override;
   SDValue PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const override;

diff  --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 934aad66956d2..a114d92397c91 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -195,7 +195,6 @@ class ValueToRegClass<ValueType T> {
 } 
 
 
-
 //===----------------------------------------------------------------------===//
 // Some Common Instruction Class Templates
 //===----------------------------------------------------------------------===//
@@ -1729,7 +1728,7 @@ def IMOV16ri : NVPTXInst<(outs Int16Regs:$dst), (ins i16imm:$src),
 def IMOV32ri : NVPTXInst<(outs Int32Regs:$dst), (ins i32imm:$src),
                          "mov.u32 \t$dst, $src;",
                          [(set Int32Regs:$dst, imm:$src)]>;
-def IMOV64i : NVPTXInst<(outs Int64Regs:$dst), (ins i64imm:$src),
+def IMOV64ri : NVPTXInst<(outs Int64Regs:$dst), (ins i64imm:$src),
                         "mov.u64 \t$dst, $src;",
                         [(set Int64Regs:$dst, imm:$src)]>;
 
@@ -1741,6 +1740,7 @@ def FMOV64ri : NVPTXInst<(outs Float64Regs:$dst), (ins f64imm:$src),
                          [(set Float64Regs:$dst, fpimm:$src)]>;
 
 def : Pat<(i32 (Wrapper texternalsym:$dst)), (IMOV32ri texternalsym:$dst)>;
+def : Pat<(i64 (Wrapper texternalsym:$dst)), (IMOV64ri texternalsym:$dst)>;
 
 //---- Copy Frame Index ----
 def LEA_ADDRi :   NVPTXInst<(outs Int32Regs:$dst), (ins MEMri:$addr),

diff  --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
index cea3dce3f1c55..73866ff3027d5 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -81,6 +81,15 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
   unsigned int getSmVersion() const { return SmVersion; }
   std::string getTargetName() const { return TargetName; }
 
+  // Get maximum value of required alignments among the supported data types.
+  // From the PTX ISA doc, section 8.2.3:
+  //  The memory consistency model relates operations executed on memory
+  //  locations with scalar data-types, which have a maximum size and alignment
+  //  of 64 bits. Memory operations with a vector data-type are modelled as a
+  //  set of equivalent memory operations with a scalar data-type, executed in
+  //  an unspecified order on the elements in the vector.
+  unsigned getMaxRequiredAlignment() const { return 8; }
+
   unsigned getPTXVersion() const { return PTXVersion; }
 
   NVPTXSubtarget &initializeSubtargetDependencies(StringRef CPU, StringRef FS);

diff  --git a/llvm/test/CodeGen/NVPTX/symbol-naming.ll b/llvm/test/CodeGen/NVPTX/symbol-naming.ll
index 68046167e7c47..d78f47a340795 100644
--- a/llvm/test/CodeGen/NVPTX/symbol-naming.ll
+++ b/llvm/test/CodeGen/NVPTX/symbol-naming.ll
@@ -1,7 +1,7 @@
-; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
-; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 | FileCheck %s
-; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_20 | %ptxas-verify %}
-; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_20 | %ptxas-verify %}
+; RUN: llc < %s -march=nvptx -mattr=+ptx60 -mcpu=sm_30 | FileCheck %s
+; RUN: llc < %s -march=nvptx64 -mattr=+ptx60 -mcpu=sm_30 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx -mattr=+ptx60 -mcpu=sm_30 | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mattr=+ptx60 -mcpu=sm_30 | %ptxas-verify %}
 
 ; Verify that the NVPTX target removes invalid symbol names prior to emitting
 ; PTX.

diff  --git a/llvm/test/CodeGen/NVPTX/vaargs.ll b/llvm/test/CodeGen/NVPTX/vaargs.ll
new file mode 100644
index 0000000000000..de8f7074b70be
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/vaargs.ll
@@ -0,0 +1,112 @@
+; RUN: llc < %s -O0 -march=nvptx -mattr=+ptx60 -mcpu=sm_30 | FileCheck %s --check-prefixes=CHECK,CHECK32
+; RUN: llc < %s -O0 -march=nvptx64 -mattr=+ptx60 -mcpu=sm_30 | FileCheck %s --check-prefixes=CHECK,CHECK64
+; RUN: %if ptxas %{ llc < %s -O0 -march=nvptx -mattr=+ptx60 -mcpu=sm_30 | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -O0 -march=nvptx64 -mattr=+ptx60 -mcpu=sm_30 | %ptxas-verify %}
+
+; CHECK: .address_size [[BITS:32|64]]
+
+%struct.__va_list_tag = type { i8*, i8*, i32, i32 }
+
+@foo_ptr = internal addrspace(1) global i32 (i32, ...)* @foo, align 8
+
+define i32 @foo(i32 %a, ...) {
+entry:
+  %al = alloca [1 x %struct.__va_list_tag], align 8
+  %ap = bitcast [1 x %struct.__va_list_tag]* %al to i8*
+  %al2 = alloca [1 x %struct.__va_list_tag], align 8
+  %ap2 = bitcast [1 x %struct.__va_list_tag]* %al2 to i8*
+
+; Test va_start
+; CHECK:         .param .align 8 .b8 foo_vararg[]
+; CHECK:         mov.u[[BITS]] [[VA_PTR:%(r|rd)[0-9]+]], foo_vararg;
+; CHECK-NEXT:    st.u[[BITS]] [%SP+0], [[VA_PTR]];
+
+  call void @llvm.va_start(i8* %ap)
+
+; Test va_copy()
+; CHECK-NEXT:	 ld.u[[BITS]] [[VA_PTR:%(r|rd)[0-9]+]], [%SP+0];
+; CHECK-NEXT:	 st.u[[BITS]] [%SP+{{[0-9]+}}], [[VA_PTR]];
+
+  call void @llvm.va_copy(i8* %ap2, i8* %ap)
+
+; Test va_arg(ap, int32_t)
+; CHECK-NEXT:    ld.u[[BITS]] [[VA_PTR:%(r|rd)[0-9]+]], [%SP+0];
+; CHECK-NEXT:    add.s[[BITS]] [[VA_PTR_TMP:%(r|rd)[0-9]+]], [[VA_PTR]], 3;
+; CHECK-NEXT:    and.b[[BITS]] [[VA_PTR_ALIGN:%(r|rd)[0-9]+]], [[VA_PTR_TMP]], -4;
+; CHECK-NEXT:    add.s[[BITS]] [[VA_PTR_NEXT:%(r|rd)[0-9]+]], [[VA_PTR_ALIGN]], 4;
+; CHECK-NEXT:    st.u[[BITS]] [%SP+0], [[VA_PTR_NEXT]];
+; CHECK-NEXT:    ld.local.u32 %r{{[0-9]+}}, [[[VA_PTR_ALIGN]]];
+
+  %0 = va_arg i8* %ap, i32
+
+; Test va_arg(ap, int64_t)
+; CHECK-NEXT:    ld.u[[BITS]] [[VA_PTR:%(r|rd)[0-9]+]], [%SP+0];
+; CHECK-NEXT:    add.s[[BITS]] [[VA_PTR_TMP:%(r|rd)[0-9]+]], [[VA_PTR]], 7;
+; CHECK-NEXT:    and.b[[BITS]] [[VA_PTR_ALIGN:%(r|rd)[0-9]+]], [[VA_PTR_TMP]], -8;
+; CHECK-NEXT:    add.s[[BITS]] [[VA_PTR_NEXT:%(r|rd)[0-9]+]], [[VA_PTR_ALIGN]], 8;
+; CHECK-NEXT:    st.u[[BITS]] [%SP+0], [[VA_PTR_NEXT]];
+; CHECK-NEXT:    ld.local.u64 %rd{{[0-9]+}}, [[[VA_PTR_ALIGN]]];
+
+  %1 = va_arg i8* %ap, i64
+
+; Test va_arg(ap, double)
+; CHECK-NEXT:    ld.u[[BITS]] [[VA_PTR:%(r|rd)[0-9]+]], [%SP+0];
+; CHECK-NEXT:    add.s[[BITS]] [[VA_PTR_TMP:%(r|rd)[0-9]+]], [[VA_PTR]], 7;
+; CHECK-NEXT:    and.b[[BITS]] [[VA_PTR_ALIGN:%(r|rd)[0-9]+]], [[VA_PTR_TMP]], -8;
+; CHECK-NEXT:    add.s[[BITS]] [[VA_PTR_NEXT:%(r|rd)[0-9]+]], [[VA_PTR_ALIGN]], 8;
+; CHECK-NEXT:    st.u[[BITS]] [%SP+0], [[VA_PTR_NEXT]];
+; CHECK-NEXT:    ld.local.f64 %fd{{[0-9]+}}, [[[VA_PTR_ALIGN]]];
+
+  %2 = va_arg i8* %ap, double
+
+; Test va_arg(ap, void *)
+; CHECK-NEXT:    ld.u[[BITS]] [[VA_PTR:%(r|rd)[0-9]+]], [%SP+0];
+; CHECK32-NEXT:  add.s32 [[VA_PTR_TMP:%r[0-9]+]], [[VA_PTR]], 3;
+; CHECK64-NEXT:  add.s64 [[VA_PTR_TMP:%rd[0-9]+]], [[VA_PTR]], 7;
+; CHECK32-NEXT:  and.b32 [[VA_PTR_ALIGN:%r[0-9]+]], [[VA_PTR_TMP]], -4;
+; CHECK64-NEXT:  and.b64 [[VA_PTR_ALIGN:%rd[0-9]+]], [[VA_PTR_TMP]], -8;
+; CHECK32-NEXT:  add.s32 [[VA_PTR_NEXT:%r[0-9]+]], [[VA_PTR_ALIGN]], 4;
+; CHECK64-NEXT:  add.s64 [[VA_PTR_NEXT:%rd[0-9]+]], [[VA_PTR_ALIGN]], 8;
+; CHECK-NEXT:    st.u[[BITS]] [%SP+0], [[VA_PTR_NEXT]];
+; CHECK-NEXT:    ld.local.u[[BITS]] %{{(r|rd)[0-9]+}}, [[[VA_PTR_ALIGN]]];
+
+  %3 = va_arg i8* %ap, i8*
+  %call = call i32 @bar(i32 %a, i32 %0, i64 %1, double %2, i8* %3)
+
+  call void @llvm.va_end(i8* %ap)
+  %4 =  va_arg i8* %ap2, i32
+  call void @llvm.va_end(i8* %ap2)
+  %5 = add i32 %call, %4
+  ret i32 %5
+}
+
+define i32 @test_foo(i32 %i, i64 %l, double %d, i8* %p) {
+; Test indirect variadic function call.
+
+; Load arguments to temporary variables
+; CHECK32:       ld.param.u32 [[ARG_VOID_PTR:%r[0-9]+]], [test_foo_param_3];
+; CHECK64:       ld.param.u64 [[ARG_VOID_PTR:%rd[0-9]+]], [test_foo_param_3];
+; CHECK-NEXT:    ld.param.f64 [[ARG_DOUBLE:%fd[0-9]+]], [test_foo_param_2];
+; CHECK-NEXT:    ld.param.u64 [[ARG_I64:%rd[0-9]+]], [test_foo_param_1];
+; CHECK-NEXT:    ld.param.u32 [[ARG_I32:%r[0-9]+]], [test_foo_param_0];
+
+; Store arguments to an array
+; CHECK32:  .param .align 8 .b8 param1[24];
+; CHECK64:  .param .align 8 .b8 param1[28];
+; CHECK-NEXT:    st.param.b32 [param1+0], [[ARG_I32]];
+; CHECK-NEXT:    st.param.b64 [param1+4], [[ARG_I64]];
+; CHECK-NEXT:    st.param.f64 [param1+12], [[ARG_DOUBLE]];
+; CHECK-NEXT:    st.param.b[[BITS]] [param1+20], [[ARG_VOID_PTR]];
+; CHECK-NEXT:    .param .b32 retval0;
+; CHECK-NEXT:    prototype_1 : .callprototype (.param .b32 _) _ (.param .b32 _, .param .align 8 .b8 _[]
+
+entry:
+  %ptr = load i32 (i32, ...)*, i32 (i32, ...)** addrspacecast (i32 (i32, ...)* addrspace(1)* @foo_ptr to i32 (i32, ...)**), align 8
+  %call = call i32 (i32, ...) %ptr(i32 4, i32 %i, i64 %l, double %d, i8* %p)
+  ret i32 %call
+}
+
+declare void @llvm.va_start(i8*)
+declare void @llvm.va_end(i8*)
+declare void @llvm.va_copy(i8*, i8*)
+declare i32 @bar(i32, i32, i64, double, i8*)

diff  --git a/llvm/test/DebugInfo/NVPTX/dbg-value-const-byref.ll b/llvm/test/DebugInfo/NVPTX/dbg-value-const-byref.ll
index ef589999b16a3..d96d5629d03f0 100644
--- a/llvm/test/DebugInfo/NVPTX/dbg-value-const-byref.ll
+++ b/llvm/test/DebugInfo/NVPTX/dbg-value-const-byref.ll
@@ -31,7 +31,7 @@ entry:
   call void @llvm.dbg.value(metadata i32 3, metadata !10, metadata !DIExpression()), !dbg !15
   %call = call i32 @f3(i32 3) #3, !dbg !16
   call void @llvm.dbg.value(metadata i32 7, metadata !10, metadata !DIExpression()), !dbg !18
-  %call1 = call i32 (...) @f1() #3, !dbg !19
+  %call1 = call i32 @f1() #3, !dbg !19
   call void @llvm.dbg.value(metadata i32 %call1, metadata !10, metadata !DIExpression()), !dbg !19
   store i32 %call1, ptr %i, align 4, !dbg !19, !tbaa !20
   call void @llvm.dbg.value(metadata ptr %i, metadata !10, metadata !DIExpression(DW_OP_deref)), !dbg !24
@@ -41,7 +41,7 @@ entry:
 
 declare i32 @f3(i32)
 
-declare i32 @f1(...)
+declare i32 @f1()
 
 declare void @f2(ptr)
 


        


More information about the llvm-commits mailing list