[llvm-branch-commits] [llvm] e28c71d - Revert "[NVPTX] Rip out vestigial variadic support (NFC) (#202385)"

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jun 16 02:59:47 PDT 2026


Author: Dmitry Vasilyev
Date: 2026-06-16T13:59:44+04:00
New Revision: e28c71d953dcdb0eacaf174a214f3cdc2ca09263

URL: https://github.com/llvm/llvm-project/commit/e28c71d953dcdb0eacaf174a214f3cdc2ca09263
DIFF: https://github.com/llvm/llvm-project/commit/e28c71d953dcdb0eacaf174a214f3cdc2ca09263.diff

LOG: Revert "[NVPTX] Rip out vestigial variadic support (NFC) (#202385)"

This reverts commit e63cd40ccce67f9472af9676185d7c87157043b4.

Added: 
    

Modified: 
    llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
    llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/lib/Target/NVPTX/NVPTXISelLowering.h
    llvm/lib/Target/NVPTX/NVPTXSubtarget.h

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index cd5c27cebc182..b2efcb0f0d2b6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1360,21 +1360,25 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
   const NVPTXMachineFunctionInfo *MFI =
       MF ? MF->getInfo<NVPTXMachineFunctionInfo>() : nullptr;
 
+  bool IsFirst = true;
   const bool IsKernelFunc = isKernelFunction(*F);
 
-  assert(!F->isVarArg() && "VarArg functions lowered in ExpandVariadics");
-
-  if (F->arg_empty()) {
+  if (F->arg_empty() && !F->isVarArg()) {
     O << "()";
     return;
   }
 
   O << "(\n";
 
-  auto EmitParam = [&](const Argument &Arg) {
+  for (const Argument &Arg : F->args()) {
     Type *Ty = Arg.getType();
     const std::string ParamSym = TLI->getParamName(F, Arg.getArgNo());
 
+    if (!IsFirst)
+      O << ",\n";
+
+    IsFirst = false;
+
     // Handle image/sampler parameters
     if (IsKernelFunc) {
       const PTXOpaqueType ArgOpaqueType = getPTXOpaqueType(Arg);
@@ -1398,7 +1402,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
           llvm_unreachable("handled above");
         }
         O << ParamSym;
-        return;
+        continue;
       }
     }
 
@@ -1420,7 +1424,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
 
       O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
         << "[" << DL.getTypeAllocSize(ETy) << "]";
-      return;
+      continue;
     }
 
     if (shouldPassAsArray(Ty)) {
@@ -1434,7 +1438,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
       O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
         << "[" << DL.getTypeAllocSize(Ty) << "]";
 
-      return;
+      continue;
     }
     // Just a scalar
     auto *PTy = dyn_cast<PointerType>(Ty);
@@ -1468,7 +1472,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
 
         O << " .align " << Arg.getParamAlign().valueOrOne().value() << " "
           << ParamSym;
-        return;
+        continue;
       }
 
       // non-pointer scalar to kernel func
@@ -1479,7 +1483,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
       else
         O << getPTXFundamentalTypeStr(Ty);
       O << " " << ParamSym;
-      return;
+      continue;
     }
     // Non-kernel function, just print .param .b<size> for ABI
     // and .reg .b<size> for non-ABI
@@ -1492,8 +1496,14 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
     } else
       Size = Ty->getPrimitiveSizeInBits();
     O << "\t.param .b" << Size << " " << ParamSym;
-  };
-  interleave(F->args(), O, EmitParam, ",\n");
+  }
+
+  if (F->isVarArg()) {
+    if (!IsFirst)
+      O << ",\n";
+    O << "\t.param .align " << STI.getMaxRequiredAlignment() << " .b8 "
+      << TLI->getParamName(F, /* vararg */ -1) << "[]";
+  }
 
   O << "\n)";
 }

diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 7555845847935..17d9f857312d6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -308,10 +308,12 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
 static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
                                LLVMContext &Ctx, CallingConv::ID CallConv,
                                Type *Ty, SmallVectorImpl<EVT> &ValueVTs,
-                               SmallVectorImpl<TypeSize> &Offsets) {
+                               SmallVectorImpl<uint64_t> &Offsets,
+                               uint64_t StartingOffset = 0) {
   SmallVector<EVT, 16> TempVTs;
-  SmallVector<TypeSize, 16> TempOffsets;
-  ComputeValueVTs(TLI, DL, Ty, TempVTs, /*MemVTs=*/nullptr, &TempOffsets);
+  SmallVector<uint64_t, 16> TempOffsets;
+  ComputeValueVTs(TLI, DL, Ty, TempVTs, /*MemVTs=*/nullptr, &TempOffsets,
+                  StartingOffset);
 
   for (const auto [VT, Off] : zip(TempVTs, TempOffsets)) {
     MVT RegisterVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, VT);
@@ -426,9 +428,10 @@ static EVT promoteScalarIntegerPTX(const EVT VT) {
 // parameter starting at index Idx using a single vectorized op of
 // size AccessSize. If so, it returns the number of param pieces
 // covered by the vector op. Otherwise, it returns 1.
+template <typename T>
 static unsigned canMergeParamLoadStoresStartingAt(
     unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs,
-    const SmallVectorImpl<TypeSize> &Offsets, Align ParamAlignment) {
+    const SmallVectorImpl<T> &Offsets, Align ParamAlignment) {
 
   // Can't vectorize if param alignment is not sufficient.
   if (ParamAlignment < AccessSize)
@@ -478,10 +481,17 @@ static unsigned canMergeParamLoadStoresStartingAt(
 // of the same size as ValueVTs indicating how each piece should be
 // loaded/stored (i.e. as a scalar, or as part of a vector
 // load/store).
+template <typename T>
 static SmallVector<unsigned, 16>
 VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
-                     const SmallVectorImpl<TypeSize> &Offsets,
-                     Align ParamAlignment) {
+                     const SmallVectorImpl<T> &Offsets, Align ParamAlignment,
+                     bool IsVAArg = false) {
+  // Set vector size to match ValueVTs and mark all elements as
+  // scalars by default.
+
+  if (IsVAArg)
+    return SmallVector<unsigned>(ValueVTs.size(), 1);
+
   SmallVector<unsigned, 16> VectorInfo;
 
   const auto GetNumElts = [&](unsigned I) -> unsigned {
@@ -796,6 +806,12 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // DEBUGTRAP can be lowered to PTX brkpt
   setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal);
 
+  // Support varargs.
+  setOperationAction(ISD::VASTART, MVT::Other, Custom);
+  setOperationAction(ISD::VAARG, MVT::Other, Custom);
+  setOperationAction(ISD::VACOPY, MVT::Other, Expand);
+  setOperationAction(ISD::VAEND, MVT::Other, Expand);
+
   setOperationAction({ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX},
                      {MVT::i16, MVT::i32, MVT::i64}, Legal);
   // PTX abs.s is undefined for INT_MIN, so ISD::ABS (which requires
@@ -1191,7 +1207,8 @@ SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
 
 std::string NVPTXTargetLowering::getPrototype(
     const DataLayout &DL, Type *RetTy, const ArgListTy &Args,
-    const SmallVectorImpl<ISD::OutputArg> &Outs, const CallBase &CB,
+    const SmallVectorImpl<ISD::OutputArg> &Outs,
+    std::optional<unsigned> FirstVAArg, const CallBase &CB,
     unsigned UniqueCallSite) const {
   auto PtrVT = getPointerTy(DL);
 
@@ -1232,13 +1249,20 @@ std::string NVPTXTargetLowering::getPrototype(
   }
   O << "_ (";
 
+  bool first = true;
+
+  const unsigned NumArgs = FirstVAArg.value_or(Args.size());
   auto AllOuts = ArrayRef(Outs);
-  auto MakeArg = [&](const unsigned I) {
+  for (const unsigned I : llvm::seq(NumArgs)) {
     const auto ArgOuts =
         AllOuts.take_while([I](auto O) { return O.OrigArgIndex == I; });
     AllOuts = AllOuts.drop_front(ArgOuts.size());
 
     Type *Ty = Args[I].Ty;
+    if (!first) {
+      O << ", ";
+    }
+    first = false;
 
     if (ArgOuts[0].Flags.isByVal()) {
       // Indirect calls need strict ABI alignment so we disable optimizations by
@@ -1250,33 +1274,34 @@ std::string NVPTXTargetLowering::getPrototype(
 
       O << ".param .align " << ParamByValAlign.value() << " .b8 _["
         << ArgOuts[0].Flags.getByValSize() << "]";
-      return;
-    }
-
-    if (shouldPassAsArray(Ty)) {
-      Align ParamAlign =
-          getPTXParamAlign(&CB, Ty, I + AttributeList::FirstArgIndex, DL);
-      O << ".param .align " << ParamAlign.value() << " .b8 _["
-        << DL.getTypeAllocSize(Ty) << "]";
-      return;
-    }
-    // i8 types in IR will be i16 types in SDAG
-    assert((getValueType(DL, Ty) == ArgOuts[0].VT ||
-            (getValueType(DL, Ty) == MVT::i8 && ArgOuts[0].VT == MVT::i16)) &&
-           "type mismatch between callee prototype and arguments");
-    // scalar type
-    unsigned sz = 0;
-    if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
-      sz = promoteScalarArgumentSize(ITy->getBitWidth());
-    } else if (isa<PointerType>(Ty)) {
-      sz = PtrVT.getSizeInBits();
     } else {
-      sz = Ty->getPrimitiveSizeInBits();
+      if (shouldPassAsArray(Ty)) {
+        Align ParamAlign =
+            getPTXParamAlign(&CB, Ty, I + AttributeList::FirstArgIndex, DL);
+        O << ".param .align " << ParamAlign.value() << " .b8 _["
+          << DL.getTypeAllocSize(Ty) << "]";
+        continue;
+      }
+      // i8 types in IR will be i16 types in SDAG
+      assert((getValueType(DL, Ty) == ArgOuts[0].VT ||
+              (getValueType(DL, Ty) == MVT::i8 && ArgOuts[0].VT == MVT::i16)) &&
+             "type mismatch between callee prototype and arguments");
+      // scalar type
+      unsigned sz = 0;
+      if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
+        sz = promoteScalarArgumentSize(ITy->getBitWidth());
+      } else if (isa<PointerType>(Ty)) {
+        sz = PtrVT.getSizeInBits();
+      } else {
+        sz = Ty->getPrimitiveSizeInBits();
+      }
+      O << ".param .b" << sz << " _";
     }
-    O << ".param .b" << sz << " _";
-  };
-  interleave(seq(Args.size()), O, MakeArg, ", ");
+  }
 
+  if (FirstVAArg)
+    O << (first ? "" : ",") << " .param .align "
+      << STI.getMaxRequiredAlignment() << " .b8 _[]";
   O << ")";
   if (shouldEmitPTXNoReturn(&CB, *nvTM))
     O << " .noreturn";
@@ -1335,7 +1360,10 @@ static SDValue correctParamType(SDValue V, EVT ExpectedVT,
 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                                        SmallVectorImpl<SDValue> &InVals) const {
 
-  assert(!CLI.IsVarArg && "Vararg functions lowered in ExpandVariadics");
+  if (CLI.IsVarArg && (STI.getPTXVersion() < 60 || STI.getSmVersion() < 30))
+    report_fatal_error(
+        "Support for variadic functions (unsized array parameter) introduced "
+        "in PTX ISA version 6.0 and requires target sm_30.");
 
   SelectionDAG &DAG = CLI.DAG;
   SDLoc dl = CLI.DL;
@@ -1381,11 +1409,32 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     return Declare;
   };
 
-  // For each argument, we declare a param scalar or a param byte array in the
-  // .param space, and store the argument value to that param scalar or array
-  // starting at offset 0.
-  assert(CLI.Args.size() == CLI.NumFixedArgs &&
-         "function with extra arguments");
+  // Variadic arguments.
+  //
+  // Normally, for each argument, we declare a param scalar or a param
+  // byte array in the .param space, and store the argument value to that
+  // param scalar or array starting at offset 0.
+  //
+  // In the case of the first variadic argument, we declare a vararg byte array
+  // with size 0. The exact size of this array isn't known at this point, so
+  // it'll be patched later. All the variadic arguments will be stored to this
+  // array at a certain offset (which gets tracked by 'VAOffset'). The offset is
+  // initially set to 0, so it can be used for non-variadic arguments (which use
+  // 0 offset) to simplify the code.
+  //
+  // After all vararg is processed, 'VAOffset' holds the size of the
+  // vararg byte array.
+  assert((CLI.IsVarArg || CLI.Args.size() == CLI.NumFixedArgs) &&
+         "Non-VarArg function with extra arguments");
+
+  const unsigned FirstVAArg = CLI.NumFixedArgs; // position of first variadic
+  unsigned VAOffset = 0; // current offset in the param array
+
+  const SDValue VADeclareParam =
+      CLI.Args.size() > FirstVAArg
+          ? MakeDeclareArrayParam(getCallParamSymbol(DAG, FirstVAArg, MVT::i32),
+                                  Align(STI.getMaxRequiredAlignment()), 0)
+          : SDValue();
 
   // Args.size() and Outs.size() need not match.
   // Outs.size() will be larger
@@ -1411,9 +1460,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     AllOuts = AllOuts.drop_front(ArgOuts.size());
     AllOutVals = AllOutVals.drop_front(ArgOuts.size());
 
+    const bool IsVAArg = (ArgI >= FirstVAArg);
     const bool IsByVal = Arg.IsByVal;
 
-    const SDValue ParamSymbol = getCallParamSymbol(DAG, ArgI, MVT::i32);
+    const SDValue ParamSymbol =
+        getCallParamSymbol(DAG, IsVAArg ? FirstVAArg : ArgI, MVT::i32);
 
     assert((!IsByVal || Arg.IndirectType) &&
            "byval arg must have indirect type");
@@ -1437,6 +1488,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
            "type size mismatch");
 
     const SDValue ArgDeclare = [&]() {
+      if (IsVAArg)
+        return VADeclareParam;
+
       if (IsByVal || shouldPassAsArray(Arg.Ty))
         return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TySize);
 
@@ -1453,12 +1507,15 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       const auto PointerInfo = refinePtrAS(SrcPtr, DAG, DL, *this);
       const Align BaseSrcAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
 
+      if (IsVAArg)
+        VAOffset = alignTo(VAOffset, ArgAlign);
+
       SmallVector<EVT, 4> ValueVTs, MemVTs;
       SmallVector<TypeSize, 4> Offsets;
       ComputeValueVTs(*this, DL, ETy, ValueVTs, &MemVTs, &Offsets);
 
       unsigned J = 0;
-      const auto VI = VectorizePTXValueVTs(MemVTs, Offsets, ArgAlign);
+      const auto VI = VectorizePTXValueVTs(MemVTs, Offsets, ArgAlign, IsVAArg);
       for (const unsigned NumElts : VI) {
         EVT LoadVT = getVectorizedVT(MemVTs[J], NumElts, Ctx);
         Align SrcAlign = commonAlignment(BaseSrcAlign, Offsets[J]);
@@ -1466,8 +1523,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
         SDValue SrcLoad =
             DAG.getLoad(LoadVT, dl, CallChain, SrcAddr, PointerInfo, SrcAlign);
 
-        Align ParamAlign = commonAlignment(ArgAlign, Offsets[J]);
-        SDValue ParamAddr = DAG.getObjectPtrOffset(dl, ParamSymbol, Offsets[J]);
+        TypeSize ParamOffset = Offsets[J].getWithIncrement(VAOffset);
+        Align ParamAlign = commonAlignment(ArgAlign, ParamOffset);
+        SDValue ParamAddr =
+            DAG.getObjectPtrOffset(dl, ParamSymbol, ParamOffset);
         SDValue StoreParam = DAG.getStore(
             ArgDeclare, dl, SrcLoad, ParamAddr,
             MachinePointerInfo(NVPTX::AddressSpace::DeviceParam), ParamAlign);
@@ -1475,10 +1534,13 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
 
         J += NumElts;
       }
+      if (IsVAArg)
+        VAOffset += TySize;
     } else {
       SmallVector<EVT, 16> VTs;
-      SmallVector<TypeSize, 16> Offsets;
-      ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, Arg.Ty, VTs, Offsets);
+      SmallVector<uint64_t, 16> Offsets;
+      ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, Arg.Ty, VTs, Offsets,
+                         VAOffset);
       assert(VTs.size() == Offsets.size() && "Size mismatch");
       assert(VTs.size() == ArgOuts.size() && "Size mismatch");
 
@@ -1503,11 +1565,30 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       };
 
       unsigned J = 0;
-      const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
+      const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
       for (const unsigned NumElts : VI) {
-        TypeSize Offset = Offsets[J];
+        const EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
+
+        unsigned Offset;
+        if (IsVAArg) {
+          // TODO: We may need to support vector types that can be passed
+          // as scalars in variadic arguments.
+          assert(NumElts == 1 &&
+                 "Vectorization should be disabled for vaargs.");
+
+          // Align each part of the variadic argument to their type.
+          VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
+          Offset = VAOffset;
+
+          const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
+          VAOffset += DL.getTypeAllocSize(TheStoreType.getTypeForEVT(Ctx));
+        } else {
+          assert(VAOffset == 0 && "VAOffset must be 0 for non-VA args");
+          Offset = Offsets[J];
+        }
 
-        SDValue Ptr = DAG.getObjectPtrOffset(dl, ParamSymbol, Offset);
+        SDValue Ptr =
+            DAG.getObjectPtrOffset(dl, ParamSymbol, TypeSize::getFixed(Offset));
 
         const MaybeAlign CurrentAlign = ExtendIntegerParam
                                             ? MaybeAlign(std::nullopt)
@@ -1541,6 +1622,17 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     }
   }
 
+  // Set the size of the vararg param byte array if the callee is a variadic
+  // function and the variadic part is not empty.
+  if (VADeclareParam) {
+    SDValue DeclareParamOps[] = {VADeclareParam.getOperand(0),
+                                 VADeclareParam.getOperand(1),
+                                 VADeclareParam.getOperand(2), GetI32(VAOffset),
+                                 VADeclareParam.getOperand(4)};
+    DAG.MorphNodeTo(VADeclareParam.getNode(), VADeclareParam.getOpcode(),
+                    VADeclareParam->getVTList(), DeclareParamOps);
+  }
+
   const auto *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
   const auto *CalleeF = Func ? dyn_cast<Function>(Func->getGlobal()) : nullptr;
 
@@ -1574,8 +1666,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     // instruction.
     // The prototype is embedded in a string and put as the operand for a
     // CallPrototype SDNode which will print out to the value of the string.
+    const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
     std::string Proto =
-        getPrototype(DL, RetTy, Args, CLI.Outs, *CB, UniqueCallSite);
+        getPrototype(DL, RetTy, Args, CLI.Outs,
+                     HasVAArgs ? std::optional(FirstVAArg) : std::nullopt, *CB,
+                     UniqueCallSite);
     const char *ProtoStr = nvTM->getStrPool().save(Proto).data();
     const SDValue PrototypeDeclare = DAG.getNode(
         NVPTXISD::CallPrototype, dl, MVT::Other,
@@ -1609,7 +1704,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   SmallVector<SDValue, 16> ProxyRegOps;
   if (!Ins.empty()) {
     SmallVector<EVT, 16> VTs;
-    SmallVector<TypeSize, 16> Offsets;
+    SmallVector<uint64_t, 16> Offsets;
     ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, RetTy, VTs, Offsets);
     assert(VTs.size() == Ins.size() && "Bad value decomposition");
 
@@ -1634,7 +1729,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       const EVT LoadVT =
           ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
       const EVT VecVT = getVectorizedVT(LoadVT, NumElts, Ctx);
-      SDValue Ptr = DAG.getObjectPtrOffset(dl, RetSymbol, Offsets[I]);
+      SDValue Ptr =
+          DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
 
       SDValue R = DAG.getLoad(
           VecVT, dl, Call, Ptr,
@@ -3385,6 +3481,10 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerFP_ROUND(Op, DAG);
   case ISD::FP_EXTEND:
     return LowerFP_EXTEND(Op, DAG);
+  case ISD::VAARG:
+    return LowerVAARG(Op, DAG);
+  case ISD::VASTART:
+    return LowerVASTART(Op, DAG);
   case ISD::FSHL:
   case ISD::FSHR:
     return lowerFSH(Op, DAG);
@@ -3464,6 +3564,63 @@ SDValue NVPTXTargetLowering::LowerADDRSPACECAST(SDValue Op,
   return Op;
 }
 
+// This function is almost a copy of SelectionDAG::expandVAArg().
+// The only diff is that this one produces loads from local address space.
+SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
+  const TargetLowering *TLI = STI.getTargetLowering();
+  SDLoc DL(Op);
+
+  SDNode *Node = Op.getNode();
+  const Value *V = cast<SrcValueSDNode>(Node->getOperand(2))->getValue();
+  EVT VT = Node->getValueType(0);
+  auto *Ty = VT.getTypeForEVT(*DAG.getContext());
+  SDValue Tmp1 = Node->getOperand(0);
+  SDValue Tmp2 = Node->getOperand(1);
+  const MaybeAlign MA(Node->getConstantOperandVal(3));
+
+  SDValue VAListLoad = DAG.getLoad(TLI->getPointerTy(DAG.getDataLayout()), DL,
+                                   Tmp1, Tmp2, MachinePointerInfo(V));
+  SDValue VAList = VAListLoad;
+
+  if (MA && *MA > TLI->getMinStackArgumentAlignment()) {
+    VAList = DAG.getNode(
+        ISD::ADD, DL, VAList.getValueType(), VAList,
+        DAG.getConstant(MA->value() - 1, DL, VAList.getValueType()));
+
+    VAList = DAG.getNode(ISD::AND, DL, VAList.getValueType(), VAList,
+                         DAG.getSignedConstant(-(int64_t)MA->value(), DL,
+                                               VAList.getValueType()));
+  }
+
+  // Increment the pointer, VAList, to the next vaarg
+  Tmp1 = DAG.getNode(ISD::ADD, DL, VAList.getValueType(), VAList,
+                     DAG.getConstant(DAG.getDataLayout().getTypeAllocSize(Ty),
+                                     DL, VAList.getValueType()));
+
+  // Store the incremented VAList to the legalized pointer
+  Tmp1 = DAG.getStore(VAListLoad.getValue(1), DL, Tmp1, Tmp2,
+                      MachinePointerInfo(V));
+
+  const Value *SrcV = Constant::getNullValue(
+      PointerType::get(*DAG.getContext(), ADDRESS_SPACE_LOCAL));
+
+  // Load the actual argument out of the pointer VAList
+  return DAG.getLoad(VT, DL, Tmp1, VAList, MachinePointerInfo(SrcV));
+}
+
+SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
+  const TargetLowering *TLI = STI.getTargetLowering();
+  SDLoc DL(Op);
+  EVT PtrVT = TLI->getPointerTy(DAG.getDataLayout());
+
+  // Store the address of unsized array <function>_vararg[] in the ap object.
+  SDValue VAReg = getParamSymbol(DAG, /* vararg */ -1, PtrVT);
+
+  const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
+  return DAG.getStore(Op.getOperand(0), DL, VAReg, Op.getOperand(1),
+                      MachinePointerInfo(SV));
+}
+
 static std::pair<MemSDNode *, uint32_t>
 convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG,
                                     const NVPTXSubtarget &STI) {
@@ -3884,14 +4041,16 @@ bool NVPTXTargetLowering::splitValueIntoRegisterParts(
 
 // This creates target external symbol for a function parameter.
 // Name of the symbol is composed from its index and the function name.
-SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, unsigned I,
+// Negative index corresponds to special parameter (unsized array) used for
+// passing variable arguments.
+SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int I,
                                             EVT T) const {
   StringRef SavedStr = nvTM->getStrPool().save(
       getParamName(&DAG.getMachineFunction().getFunction(), I));
   return DAG.getExternalSymbol(SavedStr.data(), T);
 }
 
-SDValue NVPTXTargetLowering::getCallParamSymbol(SelectionDAG &DAG, unsigned I,
+SDValue NVPTXTargetLowering::getCallParamSymbol(SelectionDAG &DAG, int I,
                                                 EVT T) const {
   const StringRef SavedStr = nvTM->getStrPool().save("param" + Twine(I));
   return DAG.getExternalSymbol(SavedStr.data(), T);
@@ -3901,8 +4060,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
     SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
     const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
     SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
-  assert(!isVarArg && "Vararg functions lowered in ExpandVariadics");
-
   const DataLayout &DL = DAG.getDataLayout();
   LLVMContext &Ctx = *DAG.getContext();
   auto PtrVT = getPointerTy(DAG.getDataLayout());
@@ -3978,7 +4135,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
       InVals.push_back(P);
     } else {
       SmallVector<EVT, 16> VTs;
-      SmallVector<TypeSize, 16> Offsets;
+      SmallVector<uint64_t, 16> Offsets;
       ComputePTXValueVTs(*this, DL, Ctx, CallConv, Ty, VTs, Offsets);
       assert(VTs.size() == ArgIns.size() && "Size mismatch");
       assert(VTs.size() == Offsets.size() && "Size mismatch");
@@ -3993,7 +4150,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
         const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
         const EVT VecVT = getVectorizedVT(LoadVT, NumElts, Ctx);
 
-        SDValue VecAddr = DAG.getObjectPtrOffset(dl, ArgSymbol, Offsets[I]);
+        SDValue VecAddr = DAG.getObjectPtrOffset(
+            dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
 
         const Align PartAlign = commonAlignment(ArgAlign, Offsets[I]);
         const unsigned AS = IsKernel ? NVPTX::AddressSpace::EntryParam
@@ -4027,8 +4185,6 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
                                  const SmallVectorImpl<ISD::OutputArg> &Outs,
                                  const SmallVectorImpl<SDValue> &OutVals,
                                  const SDLoc &dl, SelectionDAG &DAG) const {
-  assert(!isVarArg && "Vararg functions lowered in ExpandVariadics");
-
   const Function &F = DAG.getMachineFunction().getFunction();
   Type *RetTy = F.getReturnType();
 
@@ -4051,7 +4207,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
       RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
 
   SmallVector<EVT, 16> VTs;
-  SmallVector<TypeSize, 16> Offsets;
+  SmallVector<uint64_t, 16> Offsets;
   ComputePTXValueVTs(*this, DL, Ctx, CallConv, RetTy, VTs, Offsets);
   assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
 
@@ -4077,7 +4233,8 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
     SDValue Val = getBuildVectorizedValue(
         NumElts, dl, DAG, [&](unsigned K) { return GetRetVal(I + K); });
 
-    SDValue Ptr = DAG.getObjectPtrOffset(dl, RetSymbol, Offsets[I]);
+    SDValue Ptr =
+        DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
 
     Chain = DAG.getStore(Chain, dl, Val, Ptr,
                          MachinePointerInfo(NVPTX::AddressSpace::DeviceParam),
@@ -5384,11 +5541,20 @@ void NVPTXTargetLowering::getTgtMemIntrinsic(
 }
 
 // Helper for getting a function parameter name. Name is composed from
-// its index and the function name.
+// its index and the function name. Negative index corresponds to special
+// parameter (unsized array) used for passing variable arguments.
 std::string NVPTXTargetLowering::getParamName(const Function *F,
-                                              unsigned Idx) const {
-  return (getTargetMachine().getSymbol(F)->getName() + "_param_" + Twine(Idx))
-      .str();
+                                              int Idx) const {
+  std::string ParamName;
+  raw_string_ostream ParamStr(ParamName);
+
+  ParamStr << getTargetMachine().getSymbol(F)->getName();
+  if (Idx < 0)
+    ParamStr << "_vararg";
+  else
+    ParamStr << "_param_" << Idx;
+
+  return ParamName;
 }
 
 /// isLegalAddressingMode - Return true if the addressing mode represented

diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 76892c229a842..0e8dd6056af81 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -37,8 +37,9 @@ class NVPTXTargetLowering : public TargetLowering {
                           unsigned Intrinsic) const override;
 
   // Helper for getting a function parameter name. Name is composed from
-  // its index and the function name.
-  std::string getParamName(const Function *F, unsigned Idx) const;
+  // its index and the function name. Negative index corresponds to special
+  // parameter (unsized array) used for passing variable arguments.
+  std::string getParamName(const Function *F, int Idx) const;
 
   /// isLegalAddressingMode - Return true if the addressing mode represented
   /// by AM is legal for this target, for a load/store of the specified type
@@ -84,6 +85,7 @@ class NVPTXTargetLowering : public TargetLowering {
 
   std::string getPrototype(const DataLayout &DL, Type *, const ArgListTy &,
                            const SmallVectorImpl<ISD::OutputArg> &,
+                           std::optional<unsigned> FirstVAArg,
                            const CallBase &CB, unsigned UniqueCallSite) const;
 
   SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
@@ -192,8 +194,8 @@ class NVPTXTargetLowering : public TargetLowering {
   const NVPTXSubtarget &STI; // cache the subtarget here
   mutable unsigned GlobalUniqueCallSite;
 
-  SDValue getParamSymbol(SelectionDAG &DAG, unsigned I, EVT T) const;
-  SDValue getCallParamSymbol(SelectionDAG &DAG, unsigned I, EVT T) const;
+  SDValue getParamSymbol(SelectionDAG &DAG, int I, EVT T) const;
+  SDValue getCallParamSymbol(SelectionDAG &DAG, int I, EVT T) const;
   SDValue LowerADDRSPACECAST(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) const;
 
@@ -226,6 +228,9 @@ class NVPTXTargetLowering : public TargetLowering {
   SDValue LowerShiftRightParts(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerShiftLeftParts(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue LowerCopyToReg_128(SDValue Op, SelectionDAG &DAG) const;
   unsigned getNumRegisters(LLVMContext &Context, EVT VT,
                            std::optional<MVT> RegisterVT) const override;

diff  --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
index e632204a444d5..1df5d326f63a6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -334,6 +334,14 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
 
   bool hasNativeBF16Support(int Opcode) const;
 
+  // Get maximum value of required alignments among the supported data types.
+  // From the PTX ISA doc, section 8.2.3:
+  //  The memory consistency model relates operations executed on memory
+  //  locations with scalar data-types, which have a maximum size and alignment
+  //  of 64 bits. Memory operations with a vector data-type are modelled as a
+  //  set of equivalent memory operations with a scalar data-type, executed in
+  //  an unspecified order on the elements in the vector.
+  unsigned getMaxRequiredAlignment() const { return 8; }
   // Get the smallest cmpxchg word size that the hardware supports.
   unsigned getMinCmpXchgSizeInBits() const { return 32; }
 


        


More information about the llvm-branch-commits mailing list