[llvm] [NVPTX][NFC] Refactor and cleanup NVPTXISelLowering call lowering 2/n (PR #137666)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 28 09:39:53 PDT 2025


https://github.com/AlexMaclean created https://github.com/llvm/llvm-project/pull/137666

None

From 6b242a4994fb39cf998a99145cccd596f6b8750d Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Fri, 25 Apr 2025 21:31:00 +0000
Subject: [PATCH 1/3] [NVPTX][NFC] Refactor and cleanup NVPTXISelLowering 2/n

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 475 +++++++++-----------
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td     |  33 +-
 2 files changed, 220 insertions(+), 288 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index c41741ed10232..b287822e61db9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -343,33 +343,35 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
 /// and promote them to a larger size if they're not.
 ///
 /// The promoted type is placed in \p PromoteVT if the function returns true.
-static bool PromoteScalarIntegerPTX(const EVT &VT, MVT *PromotedVT) {
+static std::optional<MVT> PromoteScalarIntegerPTX(const EVT &VT) {
   if (VT.isScalarInteger()) {
+    MVT PromotedVT;
     switch (PowerOf2Ceil(VT.getFixedSizeInBits())) {
     default:
       llvm_unreachable(
           "Promotion is not suitable for scalars of size larger than 64-bits");
     case 1:
-      *PromotedVT = MVT::i1;
+      PromotedVT = MVT::i1;
       break;
     case 2:
     case 4:
     case 8:
-      *PromotedVT = MVT::i8;
+      PromotedVT = MVT::i8;
       break;
     case 16:
-      *PromotedVT = MVT::i16;
+      PromotedVT = MVT::i16;
       break;
     case 32:
-      *PromotedVT = MVT::i32;
+      PromotedVT = MVT::i32;
       break;
     case 64:
-      *PromotedVT = MVT::i64;
+      PromotedVT = MVT::i64;
       break;
     }
-    return EVT(*PromotedVT) != VT;
+    if (VT != PromotedVT)
+      return PromotedVT;
   }
-  return false;
+  return std::nullopt;
 }
 
 // Check whether we can merge loads/stores of some of the pieces of a
@@ -1451,8 +1453,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
 
   SelectionDAG &DAG = CLI.DAG;
   SDLoc dl = CLI.DL;
-  SmallVectorImpl<ISD::OutputArg> &Outs = CLI.Outs;
-  SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
   SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
   SDValue Chain = CLI.Chain;
   SDValue Callee = CLI.Callee;
@@ -1462,6 +1462,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   const CallBase *CB = CLI.CB;
   const DataLayout &DL = DAG.getDataLayout();
 
+  const auto GetI32 = [&](const unsigned I) {
+    return DAG.getConstant(I, dl, MVT::i32);
+  };
+
   // Variadic arguments.
   //
   // Normally, for each argument, we declare a param scalar or a param
@@ -1479,7 +1483,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   // vararg byte array.
 
   SDValue VADeclareParam;                 // vararg byte array
-  unsigned FirstVAArg = CLI.NumFixedArgs; // position of the first variadic
+  const unsigned FirstVAArg = CLI.NumFixedArgs; // position of first variadic
   unsigned VAOffset = 0;                  // current offset in the param array
 
   const unsigned UniqueCallSite = GlobalUniqueCallSite++;
@@ -1487,7 +1491,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl);
   SDValue InGlue = Chain.getValue(1);
 
-  unsigned ParamCount = 0;
   // Args.size() and Outs.size() need not match.
   // Outs.size() will be larger
   //   * if there is an aggregate argument with multiple fields (each field
@@ -1497,76 +1500,81 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   //     individually present in Outs.
   // So a different index should be used for indexing into Outs/OutVals.
   // See similar issue in LowerFormalArguments.
-  unsigned OIdx = 0;
+  auto AllOuts = ArrayRef(CLI.Outs);
+  auto AllOutVals = ArrayRef(CLI.OutVals);
+  assert(AllOuts.size() == AllOutVals.size() &&
+         "Outs and OutVals must be the same size");
   // Declare the .params or .reg need to pass values
   // to the function
-  for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
-    EVT VT = Outs[OIdx].VT;
-    Type *Ty = Args[i].Ty;
-    bool IsVAArg = (i >= CLI.NumFixedArgs);
-    bool IsByVal = Outs[OIdx].Flags.isByVal();
+  for (const auto [I, Arg] : llvm::enumerate(Args)) {
+    const auto ArgOuts =
+        AllOuts.take_while([I = I](auto O) { return O.OrigArgIndex == I; });
+    const auto ArgOutVals = AllOutVals.take_front(ArgOuts.size());
+    AllOuts = AllOuts.drop_front(ArgOuts.size());
+    AllOutVals = AllOutVals.drop_front(ArgOuts.size());
+
+    const bool IsVAArg = (I >= CLI.NumFixedArgs);
+    const bool IsByVal = Arg.IsByVal;
 
     SmallVector<EVT, 16> VTs;
     SmallVector<uint64_t, 16> Offsets;
 
-    assert((!IsByVal || Args[i].IndirectType) &&
+    assert((!IsByVal || Arg.IndirectType) &&
            "byval arg must have indirect type");
-    Type *ETy = (IsByVal ? Args[i].IndirectType : Ty);
+    Type *ETy = (IsByVal ? Arg.IndirectType : Arg.Ty);
     ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, IsByVal ? 0 : VAOffset);
+    assert(VTs.size() == Offsets.size() && "Size mismatch");
+    assert((IsByVal || VTs.size() == ArgOuts.size()) && "Size mismatch");
 
     Align ArgAlign;
     if (IsByVal) {
       // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
       // so we don't need to worry whether it's naturally aligned or not.
       // See TargetLowering::LowerCallTo().
-      Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
+      Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
       ArgAlign = getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
                                             InitialAlign, DL);
       if (IsVAArg)
         VAOffset = alignTo(VAOffset, ArgAlign);
     } else {
-      ArgAlign = getArgumentAlignment(CB, Ty, ParamCount + 1, DL);
+      ArgAlign = getArgumentAlignment(CB, Arg.Ty, I + 1, DL);
     }
 
-    unsigned TypeSize =
-        (IsByVal ? Outs[OIdx].Flags.getByValSize() : DL.getTypeAllocSize(Ty));
-    SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
+    const unsigned TypeSize = DL.getTypeAllocSize(ETy);
+    assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) &&
+           "type size mismatch");
 
     bool NeedAlign; // Does argument declaration specify alignment?
-    const bool PassAsArray = IsByVal || shouldPassAsArray(Ty);
+    const bool PassAsArray = IsByVal || shouldPassAsArray(Arg.Ty);
     if (IsVAArg) {
-      if (ParamCount == FirstVAArg) {
-        SDValue DeclareParamOps[] = {
-            Chain, DAG.getConstant(STI.getMaxRequiredAlignment(), dl, MVT::i32),
-            DAG.getConstant(ParamCount, dl, MVT::i32),
-            DAG.getConstant(1, dl, MVT::i32), InGlue};
-        VADeclareParam = Chain = DAG.getNode(NVPTXISD::DeclareParam, dl,
-                                             DeclareParamVTs, DeclareParamOps);
+      if (I == FirstVAArg) {
+        VADeclareParam = Chain =
+            DAG.getNode(NVPTXISD::DeclareParam, dl, {MVT::Other, MVT::Glue},
+                        {Chain, GetI32(STI.getMaxRequiredAlignment()),
+                         GetI32(I), GetI32(1), InGlue});
       }
       NeedAlign = PassAsArray;
     } else if (PassAsArray) {
       // declare .param .align <align> .b8 .param<n>[<size>];
-      SDValue DeclareParamOps[] = {
-          Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32),
-          DAG.getConstant(ParamCount, dl, MVT::i32),
-          DAG.getConstant(TypeSize, dl, MVT::i32), InGlue};
-      Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
-                          DeclareParamOps);
+      Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, {MVT::Other, MVT::Glue},
+                          {Chain, GetI32(ArgAlign.value()), GetI32(I),
+                           GetI32(TypeSize), InGlue});
       NeedAlign = true;
     } else {
+      assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
       // declare .param .b<size> .param<n>;
-      if (VT.isInteger() || VT.isFloatingPoint()) {
-        // PTX ABI requires integral types to be at least 32 bits in
-        // size. FP16 is loaded/stored using i16, so it's handled
-        // here as well.
-        TypeSize = promoteScalarArgumentSize(TypeSize * 8) / 8;
-      }
-      SDValue DeclareScalarParamOps[] = {
-          Chain, DAG.getConstant(ParamCount, dl, MVT::i32),
-          DAG.getConstant(TypeSize * 8, dl, MVT::i32),
-          DAG.getConstant(0, dl, MVT::i32), InGlue};
-      Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
-                          DeclareScalarParamOps);
+
+      // PTX ABI requires integral types to be at least 32 bits in
+      // size. FP16 is loaded/stored using i16, so it's handled
+      // here as well.
+      const unsigned PromotedSize =
+          (ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint())
+              ? promoteScalarArgumentSize(TypeSize * 8)
+              : TypeSize * 8;
+
+      Chain = DAG.getNode(
+          NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
+          {Chain, GetI32(I), GetI32(PromotedSize), GetI32(0), InGlue});
       NeedAlign = false;
     }
     InGlue = Chain.getValue(1);
@@ -1575,8 +1583,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     // than 32-bits are sign extended or zero extended, depending on
     // whether they are signed or unsigned types. This case applies
     // only to scalar parameters and not to aggregate values.
-    bool ExtendIntegerParam =
-        Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty) < 32;
+    const bool ExtendIntegerParam =
+        Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
 
     auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
     SmallVector<SDValue, 6> StoreOperands;
@@ -1587,34 +1595,34 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       if (NeedAlign)
         PartAlign = commonAlignment(ArgAlign, CurOffset);
 
-      SDValue StVal = OutVals[OIdx];
-
-      MVT PromotedVT;
-      if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) {
-        EltVT = EVT(PromotedVT);
-      }
-      if (PromoteScalarIntegerPTX(StVal.getValueType(), &PromotedVT)) {
-        llvm::ISD::NodeType Ext =
-            Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
-        StVal = DAG.getNode(Ext, dl, PromotedVT, StVal);
-      }
+      if (auto PromotedVT = PromoteScalarIntegerPTX(EltVT))
+        EltVT = *PromotedVT;
 
+      SDValue StVal;
       if (IsByVal) {
-        auto MPI = refinePtrAS(StVal, DAG, DL, *this);
-        const EVT PtrVT = StVal.getValueType();
-        SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
-                                      DAG.getConstant(CurOffset, dl, PtrVT));
+        SDValue Ptr = ArgOutVals[0];
+        auto MPI = refinePtrAS(Ptr, DAG, DL, *this);
+        SDValue SrcAddr =
+            DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(CurOffset));
 
         StVal = DAG.getLoad(EltVT, dl, TempChain, SrcAddr, MPI, PartAlign);
-      } else if (ExtendIntegerParam) {
+      } else {
+        StVal = ArgOutVals[J];
+
+        if (auto PromotedVT = PromoteScalarIntegerPTX(StVal.getValueType())) {
+          llvm::ISD::NodeType Ext =
+              ArgOuts[J].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+          StVal = DAG.getNode(Ext, dl, *PromotedVT, StVal);
+        }
+      }
+
+      if (ExtendIntegerParam) {
         assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
         // zext/sext to i32
-        StVal = DAG.getNode(Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
+        StVal = DAG.getNode(ArgOuts[J].Flags.isSExt() ? ISD::SIGN_EXTEND
                                                       : ISD::ZERO_EXTEND,
                             dl, MVT::i32, StVal);
-      }
-
-      if (!ExtendIntegerParam && EltVT.getSizeInBits() < 16) {
+      } else if (EltVT.getSizeInBits() < 16) {
         // Use 16-bit registers for small stores as it's the
         // smallest general purpose register size supported by NVPTX.
         StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
@@ -1623,36 +1631,28 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
       // scalar store. In such cases, fall back to byte stores.
       if (VectorInfo[J] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() &&
-          PartAlign.value() <
-              DL.getABITypeAlign(EltVT.getTypeForEVT(*DAG.getContext()))) {
+          PartAlign.value() < DAG.getEVTAlign(EltVT)) {
         assert(StoreOperands.empty() && "Unfinished preceeding store.");
         Chain = LowerUnalignedStoreParam(
             DAG, Chain, IsByVal ? CurOffset + VAOffset : CurOffset, EltVT,
-            StVal, InGlue, ParamCount, dl);
+            StVal, InGlue, I, dl);
 
         // LowerUnalignedStoreParam took care of inserting the necessary nodes
         // into the SDAG, so just move on to the next element.
-        if (!IsByVal)
-          ++OIdx;
         continue;
       }
 
       // New store.
       if (VectorInfo[J] & PVF_FIRST) {
-        assert(StoreOperands.empty() && "Unfinished preceding store.");
-        StoreOperands.push_back(Chain);
-        StoreOperands.push_back(
-            DAG.getConstant(IsVAArg ? FirstVAArg : ParamCount, dl, MVT::i32));
-
-        if (!IsByVal && IsVAArg) {
+        if (!IsByVal && IsVAArg)
           // Align each part of the variadic argument to their type.
-          VAOffset = alignTo(VAOffset, DL.getABITypeAlign(EltVT.getTypeForEVT(
-                                           *DAG.getContext())));
-        }
+          VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
 
-        StoreOperands.push_back(DAG.getConstant(
-            IsByVal ? CurOffset + VAOffset : (IsVAArg ? VAOffset : CurOffset),
-            dl, MVT::i32));
+        assert(StoreOperands.empty() && "Unfinished preceding store.");
+        StoreOperands.append(
+            {Chain, GetI32(IsVAArg ? FirstVAArg : I),
+             GetI32(IsByVal ? CurOffset + VAOffset
+                            : (IsVAArg ? VAOffset : CurOffset))});
       }
 
       // Record the value to store.
@@ -1699,13 +1699,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
               TheStoreType.getTypeForEVT(*DAG.getContext()));
         }
       }
-      if (!IsByVal)
-        ++OIdx;
     }
     assert(StoreOperands.empty() && "Unfinished parameter store.");
-    if (!IsByVal && VTs.size() > 0)
-      --OIdx;
-    ++ParamCount;
     if (IsByVal && IsVAArg)
       VAOffset += TypeSize;
   }
@@ -1714,7 +1709,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   MaybeAlign retAlignment = std::nullopt;
 
   // Handle Result
-  if (Ins.size() > 0) {
+  if (!Ins.empty()) {
     SmallVector<EVT, 16> resvtparts;
     ComputeValueVTs(*this, DL, RetTy, resvtparts);
 
@@ -1724,47 +1719,42 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy);
     if (!shouldPassAsArray(RetTy)) {
       resultsz = promoteScalarArgumentSize(resultsz);
-      SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-      SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
-                                  DAG.getConstant(resultsz, dl, MVT::i32),
-                                  DAG.getConstant(0, dl, MVT::i32), InGlue };
-      Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
+      SDValue DeclareRetOps[] = {Chain, GetI32(1), GetI32(resultsz), GetI32(0),
+                                 InGlue};
+      Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, {MVT::Other, MVT::Glue},
                           DeclareRetOps);
       InGlue = Chain.getValue(1);
     } else {
       retAlignment = getArgumentAlignment(CB, RetTy, 0, DL);
       assert(retAlignment && "retAlignment is guaranteed to be set");
-      SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-      SDValue DeclareRetOps[] = {
-          Chain, DAG.getConstant(retAlignment->value(), dl, MVT::i32),
-          DAG.getConstant(resultsz / 8, dl, MVT::i32),
-          DAG.getConstant(0, dl, MVT::i32), InGlue};
-      Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
-                          DeclareRetOps);
+      SDValue DeclareRetOps[] = {Chain, GetI32(retAlignment->value()),
+                                 GetI32(resultsz / 8), GetI32(0), InGlue};
+      Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl,
+                          {MVT::Other, MVT::Glue}, DeclareRetOps);
       InGlue = Chain.getValue(1);
     }
   }
 
-  bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
+  const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
   // Set the size of the vararg param byte array if the callee is a variadic
   // function and the variadic part is not empty.
   if (HasVAArgs) {
-    SDValue DeclareParamOps[] = {
-        VADeclareParam.getOperand(0), VADeclareParam.getOperand(1),
-        VADeclareParam.getOperand(2), DAG.getConstant(VAOffset, dl, MVT::i32),
-        VADeclareParam.getOperand(4)};
+    SDValue DeclareParamOps[] = {VADeclareParam.getOperand(0),
+                                 VADeclareParam.getOperand(1),
+                                 VADeclareParam.getOperand(2), GetI32(VAOffset),
+                                 VADeclareParam.getOperand(4)};
     DAG.MorphNodeTo(VADeclareParam.getNode(), VADeclareParam.getOpcode(),
                     VADeclareParam->getVTList(), DeclareParamOps);
   }
 
   // If the type of the callsite does not match that of the function, convert
   // the callsite to an indirect call.
-  bool ConvertToIndirectCall = shouldConvertToIndirectCall(CB, Func);
+  const bool ConvertToIndirectCall = shouldConvertToIndirectCall(CB, Func);
 
   // Both indirect calls and libcalls have nullptr Func. In order to distinguish
   // between them we must rely on the call site value which is valid for
   // indirect calls but is always null for libcalls.
-  bool isIndirectCall = (!Func && CB) || ConvertToIndirectCall;
+  const bool IsIndirectCall = (!Func && CB) || ConvertToIndirectCall;
 
   if (isa<ExternalSymbolSDNode>(Callee)) {
     Function* CalleeFunc = nullptr;
@@ -1778,7 +1768,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     CalleeFunc->addFnAttr("nvptx-libcall-callee", "true");
   }
 
-  if (isIndirectCall) {
+  if (IsIndirectCall) {
     // This is indirect function call case : PTX requires a prototype of the
     // form
     // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
@@ -1786,9 +1776,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     // instruction.
     // The prototype is embedded in a string and put as the operand for a
     // CallPrototype SDNode which will print out to the value of the string.
-    SDVTList ProtoVTs = DAG.getVTList(MVT::Other, MVT::Glue);
     std::string Proto = getPrototype(
-        DL, RetTy, Args, Outs, retAlignment,
+        DL, RetTy, Args, CLI.Outs, retAlignment,
         HasVAArgs
             ? std::optional<std::pair<unsigned, const APInt &>>(std::make_pair(
                   CLI.NumFixedArgs, VADeclareParam->getConstantOperandAPInt(1)))
@@ -1800,20 +1789,19 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
         DAG.getTargetExternalSymbol(ProtoStr, MVT::i32),
         InGlue,
     };
-    Chain = DAG.getNode(NVPTXISD::CallPrototype, dl, ProtoVTs, ProtoOps);
+    Chain = DAG.getNode(NVPTXISD::CallPrototype, dl, {MVT::Other, MVT::Glue},
+                        ProtoOps);
     InGlue = Chain.getValue(1);
   }
   // Op to just print "call"
-  SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-  SDValue PrintCallOps[] = {
-    Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, dl, MVT::i32), InGlue
-  };
+  SDValue PrintCallOps[] = {Chain, GetI32(Ins.empty() ? 0 : 1), InGlue};
   // We model convergent calls as separate opcodes.
-  unsigned Opcode = isIndirectCall ? NVPTXISD::PrintCall : NVPTXISD::PrintCallUni;
+  unsigned Opcode =
+      IsIndirectCall ? NVPTXISD::PrintCall : NVPTXISD::PrintCallUni;
   if (CLI.IsConvergent)
     Opcode = Opcode == NVPTXISD::PrintCallUni ? NVPTXISD::PrintConvergentCallUni
                                               : NVPTXISD::PrintConvergentCall;
-  Chain = DAG.getNode(Opcode, dl, PrintCallVTs, PrintCallOps);
+  Chain = DAG.getNode(Opcode, dl, {MVT::Other, MVT::Glue}, PrintCallOps);
   InGlue = Chain.getValue(1);
 
   if (ConvertToIndirectCall) {
@@ -1829,43 +1817,34 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   }
 
   // Ops to print out the function name
-  SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
   SDValue CallVoidOps[] = { Chain, Callee, InGlue };
-  Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps);
+  Chain =
+      DAG.getNode(NVPTXISD::CallVoid, dl, {MVT::Other, MVT::Glue}, CallVoidOps);
   InGlue = Chain.getValue(1);
 
   // Ops to print out the param list
-  SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
   SDValue CallArgBeginOps[] = { Chain, InGlue };
-  Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
+  Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, {MVT::Other, MVT::Glue},
                       CallArgBeginOps);
   InGlue = Chain.getValue(1);
 
-  for (unsigned i = 0, e = std::min(CLI.NumFixedArgs + 1, ParamCount); i != e;
-       ++i) {
-    unsigned opcode;
-    if (i == (e - 1))
-      opcode = NVPTXISD::LastCallArg;
-    else
-      opcode = NVPTXISD::CallArg;
-    SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-    SDValue CallArgOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
-                             DAG.getConstant(i, dl, MVT::i32), InGlue };
-    Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps);
+  const unsigned E = std::min<unsigned>(CLI.NumFixedArgs + 1, Args.size());
+  for (const unsigned I : llvm::seq(E)) {
+    const unsigned Opcode =
+        I == (E - 1) ? NVPTXISD::LastCallArg : NVPTXISD::CallArg;
+    SDValue CallArgOps[] = {Chain, GetI32(1), GetI32(I), InGlue};
+    Chain = DAG.getNode(Opcode, dl, {MVT::Other, MVT::Glue}, CallArgOps);
     InGlue = Chain.getValue(1);
   }
-  SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-  SDValue CallArgEndOps[] = { Chain,
-                              DAG.getConstant(isIndirectCall ? 0 : 1, dl, MVT::i32),
-                              InGlue };
-  Chain = DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps);
+  SDValue CallArgEndOps[] = {Chain, GetI32(IsIndirectCall ? 0 : 1), InGlue};
+  Chain = DAG.getNode(NVPTXISD::CallArgEnd, dl, {MVT::Other, MVT::Glue},
+                      CallArgEndOps);
   InGlue = Chain.getValue(1);
 
-  if (isIndirectCall) {
-    SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-    SDValue PrototypeOps[] = {
-        Chain, DAG.getConstant(UniqueCallSite, dl, MVT::i32), InGlue};
-    Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps);
+  if (IsIndirectCall) {
+    SDValue PrototypeOps[] = {Chain, GetI32(UniqueCallSite), InGlue};
+    Chain = DAG.getNode(NVPTXISD::Prototype, dl, {MVT::Other, MVT::Glue},
+                        PrototypeOps);
     InGlue = Chain.getValue(1);
   }
 
@@ -1881,7 +1860,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   SmallVector<SDValue, 16> TempProxyRegOps;
 
   // Generate loads from param memory/moves from registers for result
-  if (Ins.size() > 0) {
+  if (!Ins.empty()) {
     SmallVector<EVT, 16> VTs;
     SmallVector<uint64_t, 16> Offsets;
     ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0);
@@ -1896,60 +1875,57 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
     // 32-bits are sign extended or zero extended, depending on whether
     // they are signed or unsigned types.
-    bool ExtendIntegerRetVal =
+    const bool ExtendIntegerRetVal =
         RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
 
-    for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
-      bool needTruncate = false;
-      EVT TheLoadType = VTs[i];
-      EVT EltType = Ins[i].VT;
-      Align EltAlign = commonAlignment(RetAlign, Offsets[i]);
-      MVT PromotedVT;
-
-      if (PromoteScalarIntegerPTX(TheLoadType, &PromotedVT)) {
-        TheLoadType = EVT(PromotedVT);
-        EltType = EVT(PromotedVT);
-        needTruncate = true;
+    for (const unsigned I : llvm::seq(VTs.size())) {
+      bool NeedTruncate = false;
+      EVT TheLoadType = VTs[I];
+      EVT EltType = Ins[I].VT;
+      Align EltAlign = commonAlignment(RetAlign, Offsets[I]);
+
+      if (auto PromotedVT = PromoteScalarIntegerPTX(TheLoadType)) {
+        TheLoadType = *PromotedVT;
+        EltType = *PromotedVT;
+        NeedTruncate = true;
       }
 
       if (ExtendIntegerRetVal) {
         TheLoadType = MVT::i32;
         EltType = MVT::i32;
-        needTruncate = true;
+        NeedTruncate = true;
       } else if (TheLoadType.getSizeInBits() < 16) {
-        if (VTs[i].isInteger())
-          needTruncate = true;
+        if (VTs[I].isInteger())
+          NeedTruncate = true;
         EltType = MVT::i16;
       }
 
       // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
       // scalar load. In such cases, fall back to byte loads.
-      if (VectorInfo[i] == PVF_SCALAR && RetTy->isAggregateType() &&
-          EltAlign < DL.getABITypeAlign(
-                         TheLoadType.getTypeForEVT(*DAG.getContext()))) {
+      if (VectorInfo[I] == PVF_SCALAR && RetTy->isAggregateType() &&
+          EltAlign < DAG.getEVTAlign(TheLoadType)) {
         assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
         SDValue Ret = LowerUnalignedLoadRetParam(
-            DAG, Chain, Offsets[i], TheLoadType, InGlue, TempProxyRegOps, dl);
+            DAG, Chain, Offsets[I], TheLoadType, InGlue, TempProxyRegOps, dl);
         ProxyRegOps.push_back(SDValue());
         ProxyRegTruncates.push_back(std::optional<MVT>());
-        RetElts.resize(i);
+        RetElts.resize(I);
         RetElts.push_back(Ret);
 
         continue;
       }
 
       // Record index of the very first element of the vector.
-      if (VectorInfo[i] & PVF_FIRST) {
+      if (VectorInfo[I] & PVF_FIRST) {
         assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
-        VecIdx = i;
+        VecIdx = I;
       }
 
       LoadVTs.push_back(EltType);
 
-      if (VectorInfo[i] & PVF_LAST) {
-        unsigned NumElts = LoadVTs.size();
-        LoadVTs.push_back(MVT::Other);
-        LoadVTs.push_back(MVT::Glue);
+      if (VectorInfo[I] & PVF_LAST) {
+        const unsigned NumElts = LoadVTs.size();
+        LoadVTs.append({MVT::Other, MVT::Glue});
         NVPTXISD::NodeType Op;
         switch (NumElts) {
         case 1:
@@ -1965,21 +1941,20 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
           llvm_unreachable("Invalid vector info.");
         }
 
-        SDValue LoadOperands[] = {
-            Chain, DAG.getConstant(1, dl, MVT::i32),
-            DAG.getConstant(Offsets[VecIdx], dl, MVT::i32), InGlue};
+        SDValue LoadOperands[] = {Chain, GetI32(1), GetI32(Offsets[VecIdx]),
+                                  InGlue};
         SDValue RetVal = DAG.getMemIntrinsicNode(
             Op, dl, DAG.getVTList(LoadVTs), LoadOperands, TheLoadType,
             MachinePointerInfo(), EltAlign,
             MachineMemOperand::MOLoad);
 
-        for (unsigned j = 0; j < NumElts; ++j) {
-          ProxyRegOps.push_back(RetVal.getValue(j));
+        for (const unsigned J : llvm::seq(NumElts)) {
+          ProxyRegOps.push_back(RetVal.getValue(J));
 
-          if (needTruncate)
-            ProxyRegTruncates.push_back(std::optional<MVT>(Ins[VecIdx + j].VT));
+          if (NeedTruncate)
+            ProxyRegTruncates.push_back(Ins[VecIdx + J].VT);
           else
-            ProxyRegTruncates.push_back(std::optional<MVT>());
+            ProxyRegTruncates.push_back(std::nullopt);
         }
 
         Chain = RetVal.getValue(NumElts);
@@ -1999,33 +1974,31 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   // Append ProxyReg instructions to the chain to make sure that `callseq_end`
   // will not get lost. Otherwise, during libcalls expansion, the nodes can become
   // dangling.
-  for (unsigned i = 0; i < ProxyRegOps.size(); ++i) {
-    if (i < RetElts.size() && RetElts[i]) {
-      InVals.push_back(RetElts[i]);
+  for (const unsigned I : llvm::seq(ProxyRegOps.size())) {
+    if (I < RetElts.size() && RetElts[I]) {
+      InVals.push_back(RetElts[I]);
       continue;
     }
 
     SDValue Ret = DAG.getNode(
-      NVPTXISD::ProxyReg, dl,
-      DAG.getVTList(ProxyRegOps[i].getSimpleValueType(), MVT::Other, MVT::Glue),
-      { Chain, ProxyRegOps[i], InGlue }
-    );
+        NVPTXISD::ProxyReg, dl,
+        {ProxyRegOps[I].getSimpleValueType(), MVT::Other, MVT::Glue},
+        {Chain, ProxyRegOps[I], InGlue});
 
     Chain = Ret.getValue(1);
     InGlue = Ret.getValue(2);
 
-    if (ProxyRegTruncates[i]) {
-      Ret = DAG.getNode(ISD::TRUNCATE, dl, *ProxyRegTruncates[i], Ret);
+    if (ProxyRegTruncates[I]) {
+      Ret = DAG.getNode(ISD::TRUNCATE, dl, *ProxyRegTruncates[I], Ret);
     }
 
     InVals.push_back(Ret);
   }
 
   for (SDValue &T : TempProxyRegOps) {
-    SDValue Repl = DAG.getNode(
-        NVPTXISD::ProxyReg, dl,
-        DAG.getVTList(T.getSimpleValueType(), MVT::Other, MVT::Glue),
-        {Chain, T.getOperand(0), InGlue});
+    SDValue Repl = DAG.getNode(NVPTXISD::ProxyReg, dl,
+                               {T.getSimpleValueType(), MVT::Other, MVT::Glue},
+                               {Chain, T.getOperand(0), InGlue});
     DAG.ReplaceAllUsesWith(T, Repl);
     DAG.RemoveDeadNode(T.getNode());
 
@@ -3451,29 +3424,29 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
         // That's the last element of this store op.
         if (VectorInfo[PartI] & PVF_LAST) {
           const unsigned NumElts = PartI - VecIdx + 1;
-          EVT EltVT = VTs[PartI];
-          // i1 is loaded/stored as i8.
-          EVT LoadVT = EltVT;
-          if (EltVT == MVT::i1)
-            LoadVT = MVT::i8;
-          else if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
+          const EVT EltVT = VTs[PartI];
+          const EVT LoadVT = [&]() -> EVT {
+            // i1 is loaded/stored as i8.
+            if (EltVT == MVT::i1)
+              return MVT::i8;
             // getLoad needs a vector type, but it can't handle
             // vectors which contain v2f16 or v2bf16 elements. So we must load
             // using i32 here and then bitcast back.
-            LoadVT = MVT::i32;
+            if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
+              return MVT::i32;
+            return EltVT;
+          }();
 
-          EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
-          SDValue VecAddr =
-              DAG.getNode(ISD::ADD, dl, PtrVT, ArgSymbol,
-                          DAG.getConstant(Offsets[VecIdx], dl, PtrVT));
+          const EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
+          SDValue VecAddr = DAG.getObjectPtrOffset(
+              dl, ArgSymbol, TypeSize::getFixed(Offsets[VecIdx]));
 
           const MaybeAlign PartAlign = [&]() -> MaybeAlign {
             if (aggregateIsPacked)
               return Align(1);
             if (NumElts != 1)
               return std::nullopt;
-            Align PartAlign =
-                DL.getABITypeAlign(EltVT.getTypeForEVT(F->getContext()));
+            Align PartAlign = DAG.getEVTAlign(EltVT);
             return commonAlignment(PartAlign, Offsets[PartI]);
           }();
           SDValue P =
@@ -3486,26 +3459,20 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
           for (const unsigned J : llvm::seq(NumElts)) {
             SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
                                       DAG.getIntPtrConstant(J, dl));
-            // We've loaded i1 as an i8 and now must truncate it back to i1
-            if (EltVT == MVT::i1)
-              Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
-            // v2f16 was loaded as an i32. Now we must bitcast it back.
-            Elt = DAG.getBitcast(EltVT, Elt);
-
-            // If a promoted integer type is used, truncate down to the original
-            MVT PromotedVT;
-            if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) {
-              Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
-            }
 
-            // Extend the element if necessary (e.g. an i8 is loaded
+            // Extend or truncate the element if necessary (e.g. an i8 is loaded
             // into an i16 register)
-            if (ArgIns[PartI].VT.getFixedSizeInBits() !=
-                LoadVT.getFixedSizeInBits()) {
-              assert(ArgIns[PartI].VT.isInteger() && LoadVT.isInteger() &&
+            const EVT ExpectedVT = ArgIns[PartI].VT;
+            if (ExpectedVT.getFixedSizeInBits() !=
+                Elt.getValueType().getFixedSizeInBits()) {
+              assert(ExpectedVT.isScalarInteger() &&
+                     Elt.getValueType().isScalarInteger() &&
                     "Non-integer argument type size mismatch");
               Elt = DAG.getExtOrTrunc(ArgIns[PartI].Flags.isSExt(), Elt, dl,
-                                      ArgIns[PartI].VT);
+                                      ExpectedVT);
+            } else {
+              // v2f16 was loaded as an i32. Now we must bitcast it back.
+              Elt = DAG.getBitcast(EltVT, Elt);
             }
             InVals.push_back(Elt);
           }
@@ -3561,47 +3528,37 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
   Type *RetTy = MF.getFunction().getReturnType();
 
   const DataLayout &DL = DAG.getDataLayout();
-  SmallVector<SDValue, 16> PromotedOutVals;
   SmallVector<EVT, 16> VTs;
   SmallVector<uint64_t, 16> Offsets;
   ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
   assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
 
-  for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
-    SDValue PromotedOutVal = OutVals[i];
-    MVT PromotedVT;
-    if (PromoteScalarIntegerPTX(VTs[i], &PromotedVT)) {
-      VTs[i] = EVT(PromotedVT);
-    }
-    if (PromoteScalarIntegerPTX(PromotedOutVal.getValueType(), &PromotedVT)) {
-      llvm::ISD::NodeType Ext =
-          Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
-      PromotedOutVal = DAG.getNode(Ext, dl, PromotedVT, PromotedOutVal);
-    }
-    PromotedOutVals.push_back(PromotedOutVal);
-  }
+  for (const unsigned I : llvm::seq(VTs.size()))
+    if (auto PromotedVT = PromoteScalarIntegerPTX(VTs[I]))
+      VTs[I] = *PromotedVT;
 
   auto VectorInfo = VectorizePTXValueVTs(
       VTs, Offsets,
-      RetTy->isSized() ? getFunctionParamOptimizedAlign(&F, RetTy, DL)
+      !RetTy->isVoidTy() ? getFunctionParamOptimizedAlign(&F, RetTy, DL)
                        : Align(1));
 
   // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
   // 32-bits are sign extended or zero extended, depending on whether
   // they are signed or unsigned types.
-  bool ExtendIntegerRetVal =
+  const bool ExtendIntegerRetVal =
       RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
 
   SmallVector<SDValue, 6> StoreOperands;
-  for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
-    SDValue OutVal = OutVals[i];
-    SDValue RetVal = PromotedOutVals[i];
+  for (const unsigned I : llvm::seq(VTs.size())) {
+    SDValue RetVal = OutVals[I];
+    assert(!PromoteScalarIntegerPTX(RetVal.getValueType()) &&
+           "OutVal type should always be legal");
 
     if (ExtendIntegerRetVal) {
-      RetVal = DAG.getNode(Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND
+      RetVal = DAG.getNode(Outs[I].Flags.isSExt() ? ISD::SIGN_EXTEND
                                                   : ISD::ZERO_EXTEND,
                            dl, MVT::i32, RetVal);
-    } else if (OutVal.getValueSizeInBits() < 16) {
+    } else if (RetVal.getValueSizeInBits() < 16) {
       // Use 16-bit registers for small load-stores as it's the
       // smallest general purpose register size supported by NVPTX.
       RetVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, RetVal);
@@ -3609,15 +3566,14 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
 
     // If we have a PVF_SCALAR entry, it may not even be sufficiently aligned
     // for a scalar store. In such cases, fall back to byte stores.
-    if (VectorInfo[i] == PVF_SCALAR && RetTy->isAggregateType()) {
-      EVT ElementType = ExtendIntegerRetVal ? MVT::i32 : VTs[i];
-      Align ElementTypeAlign =
-          DL.getABITypeAlign(ElementType.getTypeForEVT(RetTy->getContext()));
-      Align ElementAlign =
-          commonAlignment(DL.getABITypeAlign(RetTy), Offsets[i]);
+    if (VectorInfo[I] == PVF_SCALAR && RetTy->isAggregateType()) {
+      const EVT ElementType = ExtendIntegerRetVal ? MVT::i32 : VTs[I];
+      const Align ElementTypeAlign = DAG.getEVTAlign(ElementType);
+      const Align ElementAlign =
+          commonAlignment(DL.getABITypeAlign(RetTy), Offsets[I]);
       if (ElementAlign < ElementTypeAlign) {
         assert(StoreOperands.empty() && "Orphaned operand list.");
-        Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[i], ElementType,
+        Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[I], ElementType,
                                        RetVal, dl);
 
         // The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes
@@ -3627,17 +3583,16 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
     }
 
     // New load/store. Record chain and offset operands.
-    if (VectorInfo[i] & PVF_FIRST) {
+    if (VectorInfo[I] & PVF_FIRST) {
       assert(StoreOperands.empty() && "Orphaned operand list.");
-      StoreOperands.push_back(Chain);
-      StoreOperands.push_back(DAG.getConstant(Offsets[i], dl, MVT::i32));
+      StoreOperands.append({Chain, DAG.getConstant(Offsets[I], dl, MVT::i32)});
     }
 
     // Record the value to return.
     StoreOperands.push_back(RetVal);
 
     // That's the last element of this store op.
-    if (VectorInfo[i] & PVF_LAST) {
+    if (VectorInfo[I] & PVF_LAST) {
       NVPTXISD::NodeType Op;
       unsigned NumElts = StoreOperands.size() - 2;
       switch (NumElts) {
@@ -3656,7 +3611,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
 
       // Adjust type of load/store op if we've extended the scalar
       // return value.
-      EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[i];
+      EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[I];
       Chain = DAG.getMemIntrinsicNode(
           Op, dl, DAG.getVTList(MVT::Other), StoreOperands, TheStoreType,
           MachinePointerInfo(), Align(1), MachineMemOperand::MOStore);
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 043da14bcb236..77be311f4e496 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -2050,8 +2050,7 @@ def SDTDeclareScalarParamProfile :
 def SDTLoadParamProfile : SDTypeProfile<1, 2, [SDTCisInt<1>, SDTCisInt<2>]>;
 def SDTLoadParamV2Profile : SDTypeProfile<2, 2, [SDTCisSameAs<0, 1>, SDTCisInt<2>, SDTCisInt<3>]>;
 def SDTLoadParamV4Profile : SDTypeProfile<4, 2, [SDTCisInt<4>, SDTCisInt<5>]>;
-def SDTPrintCallProfile : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
-def SDTPrintCallUniProfile : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
+def SDTPrintCallProfile : SDTypeProfile<0, 1, [SDTCisVT<0, i32>]>;
 def SDTStoreParamProfile : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>]>;
 def SDTStoreParamV2Profile : SDTypeProfile<0, 4, [SDTCisInt<0>, SDTCisInt<1>]>;
 def SDTStoreParamV4Profile : SDTypeProfile<0, 6, [SDTCisInt<0>, SDTCisInt<1>]>;
@@ -2095,10 +2094,10 @@ def PrintConvergentCall :
   SDNode<"NVPTXISD::PrintConvergentCall", SDTPrintCallProfile,
          [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
 def PrintCallUni :
-  SDNode<"NVPTXISD::PrintCallUni", SDTPrintCallUniProfile,
+  SDNode<"NVPTXISD::PrintCallUni", SDTPrintCallProfile,
          [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
 def PrintConvergentCallUni :
-  SDNode<"NVPTXISD::PrintConvergentCallUni", SDTPrintCallUniProfile,
+  SDNode<"NVPTXISD::PrintConvergentCallUni", SDTPrintCallProfile,
          [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
 def StoreParam :
   SDNode<"NVPTXISD::StoreParam", SDTStoreParamProfile,
@@ -2247,31 +2246,9 @@ let mayStore = true in {
 let isCall=1 in {
   multiclass CALL<string OpcStr, SDNode OpNode> {
      def PrintCallNoRetInst : NVPTXInst<(outs), (ins),
-       !strconcat(OpcStr, " "), [(OpNode (i32 0))]>;
+       OpcStr # " ", [(OpNode 0)]>;
      def PrintCallRetInst1 : NVPTXInst<(outs), (ins),
-       !strconcat(OpcStr, " (retval0), "), [(OpNode (i32 1))]>;
-     def PrintCallRetInst2 : NVPTXInst<(outs), (ins),
-       !strconcat(OpcStr, " (retval0, retval1), "), [(OpNode (i32 2))]>;
-     def PrintCallRetInst3 : NVPTXInst<(outs), (ins),
-       !strconcat(OpcStr, " (retval0, retval1, retval2), "), [(OpNode (i32 3))]>;
-     def PrintCallRetInst4 : NVPTXInst<(outs), (ins),
-       !strconcat(OpcStr, " (retval0, retval1, retval2, retval3), "),
-       [(OpNode (i32 4))]>;
-     def PrintCallRetInst5 : NVPTXInst<(outs), (ins),
-       !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4), "),
-       [(OpNode (i32 5))]>;
-     def PrintCallRetInst6 : NVPTXInst<(outs), (ins),
-       !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, "
-                            "retval5), "),
-       [(OpNode (i32 6))]>;
-     def PrintCallRetInst7 : NVPTXInst<(outs), (ins),
-       !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, "
-                            "retval5, retval6), "),
-       [(OpNode (i32 7))]>;
-     def PrintCallRetInst8 : NVPTXInst<(outs), (ins),
-       !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, "
-                            "retval5, retval6, retval7), "),
-       [(OpNode (i32 8))]>;
+       OpcStr # " (retval0), ", [(OpNode 1)]>;
   }
 }
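
The core cleanup in this first patch is replacing the bool-plus-out-parameter
idiom of PromoteScalarIntegerPTX with a std::optional<MVT> return value, so
call sites collapse to a single if-init. A minimal standalone sketch of the
same pattern, using a plain bit-width in place of LLVM's EVT/MVT (the names
and width mapping here are illustrative analogues, not LLVM API):

#include <bit>
#include <optional>

// Simplified analogue of PromoteScalarIntegerPTX: return the width a
// scalar integer must be promoted to, or std::nullopt if it is already
// a suitable PTX width. std::bit_ceil plays the role of PowerOf2Ceil.
std::optional<unsigned> promoteScalarIntegerWidth(unsigned Bits) {
  unsigned Promoted = std::bit_ceil(Bits);
  if (Promoted == 2 || Promoted == 4)
    Promoted = 8; // i2/i4/i8 are all handled as i8
  if (Promoted == Bits)
    return std::nullopt; // already the promoted width; nothing to do
  return Promoted;
}

// Call sites shrink from bool + out-parameter to a single if-init:
//   if (auto Promoted = promoteScalarIntegerWidth(Width))
//     Width = *Promoted;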
 

From baf403827be872dda63ab49ab0ac91470ab4397f Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Sun, 27 Apr 2025 17:57:23 +0000
Subject: [PATCH 2/3] [NVPTX][NFC] Refactor parameter vectorization

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 528 +++++++++-----------
 1 file changed, 237 insertions(+), 291 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index b287822e61db9..b8c419bee53ae 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -428,16 +428,6 @@ static unsigned CanMergeParamLoadStoresStartingAt(
   return NumElts;
 }
 
-// Flags for tracking per-element vectorization state of loads/stores
-// of a flattened function parameter or return value.
-enum ParamVectorizationFlags {
-  PVF_INNER = 0x0, // Middle elements of a vector.
-  PVF_FIRST = 0x1, // First element of the vector.
-  PVF_LAST = 0x2,  // Last element of the vector.
-  // Scalar is effectively a 1-element vector.
-  PVF_SCALAR = PVF_FIRST | PVF_LAST
-};
-
 // Computes whether and how we can vectorize the loads/stores of a
 // flattened function parameter or return value.
 //
@@ -446,52 +436,39 @@ enum ParamVectorizationFlags {
 // of the same size as ValueVTs indicating how each piece should be
 // loaded/stored (i.e. as a scalar, or as part of a vector
 // load/store).
-static SmallVector<ParamVectorizationFlags, 16>
+static SmallVector<unsigned, 16>
 VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
                      const SmallVectorImpl<uint64_t> &Offsets,
                      Align ParamAlignment, bool IsVAArg = false) {
   // Set vector size to match ValueVTs and mark all elements as
   // scalars by default.
-  SmallVector<ParamVectorizationFlags, 16> VectorInfo;
-  VectorInfo.assign(ValueVTs.size(), PVF_SCALAR);
+  SmallVector<unsigned, 16> VectorInfo;
 
-  if (IsVAArg)
+  if (IsVAArg) {
+    VectorInfo.assign(ValueVTs.size(), 1);
     return VectorInfo;
+  }
 
-  // Check what we can vectorize using 128/64/32-bit accesses.
-  for (int I = 0, E = ValueVTs.size(); I != E; ++I) {
-    // Skip elements we've already processed.
-    assert(VectorInfo[I] == PVF_SCALAR && "Unexpected vector info state.");
-    for (unsigned AccessSize : {16, 8, 4, 2}) {
-      unsigned NumElts = CanMergeParamLoadStoresStartingAt(
+  const auto GetNumElts = [&](unsigned I) -> unsigned {
+    for (const unsigned AccessSize : {16, 8, 4, 2}) {
+      const unsigned NumElts = CanMergeParamLoadStoresStartingAt(
           I, AccessSize, ValueVTs, Offsets, ParamAlignment);
-      // Mark vectorized elements.
-      switch (NumElts) {
-      default:
-        llvm_unreachable("Unexpected return value");
-      case 1:
-        // Can't vectorize using this size, try next smaller size.
-        continue;
-      case 2:
-        assert(I + 1 < E && "Not enough elements.");
-        VectorInfo[I] = PVF_FIRST;
-        VectorInfo[I + 1] = PVF_LAST;
-        I += 1;
-        break;
-      case 4:
-        assert(I + 3 < E && "Not enough elements.");
-        VectorInfo[I] = PVF_FIRST;
-        VectorInfo[I + 1] = PVF_INNER;
-        VectorInfo[I + 2] = PVF_INNER;
-        VectorInfo[I + 3] = PVF_LAST;
-        I += 3;
-        break;
-      }
-      // Break out of the inner loop because we've already succeeded
-      // using largest possible AccessSize.
-      break;
+      assert((NumElts == 1 || NumElts == 2 || NumElts == 4) &&
+             "Unexpected vectorization size");
+      if (NumElts != 1)
+        return NumElts;
     }
+    return 1;
+  };
+
+  // Check what we can vectorize using 128/64/32-bit accesses.
+  for (unsigned I = 0, E = ValueVTs.size(); I != E;) {
+    const unsigned NumElts = GetNumElts(I);
+    VectorInfo.push_back(NumElts);
+    I += NumElts;
   }
+  assert(std::accumulate(VectorInfo.begin(), VectorInfo.end(), 0u) ==
+         ValueVTs.size());
   return VectorInfo;
 }
 
@@ -1513,7 +1490,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     AllOuts = AllOuts.drop_front(ArgOuts.size());
     AllOutVals = AllOutVals.drop_front(ArgOuts.size());
 
-    const bool IsVAArg = (I >= CLI.NumFixedArgs);
+    const bool IsVAArg = (I >= FirstVAArg);
     const bool IsByVal = Arg.IsByVal;
 
     SmallVector<EVT, 16> VTs;
@@ -1586,32 +1563,22 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     const bool ExtendIntegerParam =
         Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
 
-    auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
-    SmallVector<SDValue, 6> StoreOperands;
-    for (const unsigned J : llvm::seq(VTs.size())) {
-      EVT EltVT = VTs[J];
-      const int CurOffset = Offsets[J];
-      MaybeAlign PartAlign;
-      if (NeedAlign)
-        PartAlign = commonAlignment(ArgAlign, CurOffset);
-
-      if (auto PromotedVT = PromoteScalarIntegerPTX(EltVT))
-        EltVT = *PromotedVT;
-
+    const auto GetStoredValue = [&](const unsigned I, EVT EltVT,
+                                    MaybeAlign PartAlign) {
       SDValue StVal;
       if (IsByVal) {
         SDValue Ptr = ArgOutVals[0];
         auto MPI = refinePtrAS(Ptr, DAG, DL, *this);
         SDValue SrcAddr =
-            DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(CurOffset));
+            DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(Offsets[I]));
 
         StVal = DAG.getLoad(EltVT, dl, TempChain, SrcAddr, MPI, PartAlign);
       } else {
-        StVal = ArgOutVals[J];
+        StVal = ArgOutVals[I];
 
         if (auto PromotedVT = PromoteScalarIntegerPTX(StVal.getValueType())) {
           llvm::ISD::NodeType Ext =
-              ArgOuts[J].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+              ArgOuts[I].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
           StVal = DAG.getNode(Ext, dl, *PromotedVT, StVal);
         }
       }
@@ -1619,7 +1586,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       if (ExtendIntegerParam) {
         assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
         // zext/sext to i32
-        StVal = DAG.getNode(ArgOuts[J].Flags.isSExt() ? ISD::SIGN_EXTEND
+        StVal = DAG.getNode(ArgOuts[I].Flags.isSExt() ? ISD::SIGN_EXTEND
                                                       : ISD::ZERO_EXTEND,
                             dl, MVT::i32, StVal);
       } else if (EltVT.getSizeInBits() < 16) {
@@ -1627,81 +1594,91 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
         // smallest general purpose register size supported by NVPTX.
         StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
       }
+      return StVal;
+    };
+
+    const auto VectorInfo =
+        VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
+
+    unsigned J = 0;
+    for (const unsigned NumElts : VectorInfo) {
+      const int CurOffset = Offsets[J];
+      EVT EltVT = VTs[J];
+      MaybeAlign PartAlign;
+      if (NeedAlign)
+        PartAlign = commonAlignment(ArgAlign, CurOffset);
+
+      if (auto PromotedVT = PromoteScalarIntegerPTX(EltVT))
+        EltVT = *PromotedVT;
 
       // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
       // scalar store. In such cases, fall back to byte stores.
-      if (VectorInfo[J] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() &&
+      if (NumElts == 1 && !IsVAArg && PartAlign.has_value() &&
           PartAlign.value() < DAG.getEVTAlign(EltVT)) {
-        assert(StoreOperands.empty() && "Unfinished preceeding store.");
-        Chain = LowerUnalignedStoreParam(
-            DAG, Chain, IsByVal ? CurOffset + VAOffset : CurOffset, EltVT,
-            StVal, InGlue, I, dl);
+
+        SDValue StVal = GetStoredValue(J, EltVT, PartAlign);
+        Chain = LowerUnalignedStoreParam(DAG, Chain,
+                                         CurOffset + (IsByVal ? VAOffset : 0),
+                                         EltVT, StVal, InGlue, I, dl);
 
         // LowerUnalignedStoreParam took care of inserting the necessary nodes
         // into the SDAG, so just move on to the next element.
+        J++;
         continue;
       }
 
-      // New store.
-      if (VectorInfo[J] & PVF_FIRST) {
-        if (!IsByVal && IsVAArg)
-          // Align each part of the variadic argument to their type.
-          VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
-
-        assert(StoreOperands.empty() && "Unfinished preceding store.");
-        StoreOperands.append(
-            {Chain, GetI32(IsVAArg ? FirstVAArg : I),
-             GetI32(IsByVal ? CurOffset + VAOffset
-                            : (IsVAArg ? VAOffset : CurOffset))});
-      }
+      if (IsVAArg && !IsByVal)
+        // Align each part of the variadic argument to their type.
+        VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
 
-      // Record the value to store.
-      StoreOperands.push_back(StVal);
+      assert((IsVAArg || VAOffset == 0) &&
+             "VAOffset must be 0 for non-VA args");
+      SmallVector<SDValue, 6> StoreOperands{
+          Chain, GetI32(IsVAArg ? FirstVAArg : I),
+          GetI32(IsByVal ? CurOffset + VAOffset
+                         : (IsVAArg ? VAOffset : CurOffset))};
 
-      if (VectorInfo[J] & PVF_LAST) {
-        const unsigned NumElts = StoreOperands.size() - 3;
-        NVPTXISD::NodeType Op;
-        switch (NumElts) {
-        case 1:
-          Op = NVPTXISD::StoreParam;
-          break;
-        case 2:
-          Op = NVPTXISD::StoreParamV2;
-          break;
-        case 4:
-          Op = NVPTXISD::StoreParamV4;
-          break;
-        default:
-          llvm_unreachable("Invalid vector info.");
-        }
+      // Record the values to store.
+      for (const unsigned K : llvm::seq(NumElts))
+        StoreOperands.push_back(GetStoredValue(J + K, EltVT, PartAlign));
+      StoreOperands.push_back(InGlue);
 
-        StoreOperands.push_back(InGlue);
-
-        // Adjust type of the store op if we've extended the scalar
-        // return value.
-        EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
-
-        Chain = DAG.getMemIntrinsicNode(
-            Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands,
-            TheStoreType, MachinePointerInfo(), PartAlign,
-            MachineMemOperand::MOStore);
-        InGlue = Chain.getValue(1);
+      NVPTXISD::NodeType Op;
+      switch (NumElts) {
+      case 1:
+        Op = NVPTXISD::StoreParam;
+        break;
+      case 2:
+        Op = NVPTXISD::StoreParamV2;
+        break;
+      case 4:
+        Op = NVPTXISD::StoreParamV4;
+        break;
+      default:
+        llvm_unreachable("Invalid vector info.");
+      }
+      // Adjust type of the store op if we've extended the scalar
+      // return value.
+      EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
 
-        // Cleanup.
-        StoreOperands.clear();
+      Chain = DAG.getMemIntrinsicNode(
+          Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands,
+          TheStoreType, MachinePointerInfo(), PartAlign,
+          MachineMemOperand::MOStore);
+      InGlue = Chain.getValue(1);
 
-        // TODO: We may need to support vector types that can be passed
-        // as scalars in variadic arguments.
-        if (!IsByVal && IsVAArg) {
-          assert(NumElts == 1 &&
-                 "Vectorization is expected to be disabled for variadics.");
-          VAOffset += DL.getTypeAllocSize(
-              TheStoreType.getTypeForEVT(*DAG.getContext()));
-        }
+      // TODO: We may need to support vector types that can be passed
+      // as scalars in variadic arguments.
+      if (IsVAArg && !IsByVal) {
+        assert(NumElts == 1 &&
+               "Vectorization is expected to be disabled for variadics.");
+        VAOffset +=
+            DL.getTypeAllocSize(TheStoreType.getTypeForEVT(*DAG.getContext()));
       }
+
+      J += NumElts;
     }
-    assert(StoreOperands.empty() && "Unfinished parameter store.");
-    if (IsByVal && IsVAArg)
+    if (IsVAArg && IsByVal)
       VAOffset += TypeSize;
   }
 
@@ -1716,10 +1693,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     // Declare
     //  .param .align N .b8 retval0[<size-in-bytes>], or
     //  .param .b<size-in-bits> retval0
-    unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy);
+    const unsigned ResultSize = DL.getTypeAllocSizeInBits(RetTy);
     if (!shouldPassAsArray(RetTy)) {
-      resultsz = promoteScalarArgumentSize(resultsz);
-      SDValue DeclareRetOps[] = {Chain, GetI32(1), GetI32(resultsz), GetI32(0),
+      const unsigned PromotedResultSize = promoteScalarArgumentSize(ResultSize);
+      SDValue DeclareRetOps[] = {Chain, GetI32(1), GetI32(PromotedResultSize), GetI32(0),
                                  InGlue};
       Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, {MVT::Other, MVT::Glue},
                           DeclareRetOps);
@@ -1728,7 +1705,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       retAlignment = getArgumentAlignment(CB, RetTy, 0, DL);
       assert(retAlignment && "retAlignment is guaranteed to be set");
       SDValue DeclareRetOps[] = {Chain, GetI32(retAlignment->value()),
-                                 GetI32(resultsz / 8), GetI32(0), InGlue};
+                                 GetI32(ResultSize / 8), GetI32(0), InGlue};
       Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl,
                           {MVT::Other, MVT::Glue}, DeclareRetOps);
       InGlue = Chain.getValue(1);
@@ -1866,11 +1843,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0);
     assert(VTs.size() == Ins.size() && "Bad value decomposition");
 
-    Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
-    auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
-
-    SmallVector<EVT, 6> LoadVTs;
-    int VecIdx = -1; // Index of the first element of the vector.
+    const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
 
     // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
     // 32-bits are sign extended or zero extended, depending on whether
@@ -1878,11 +1851,13 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     const bool ExtendIntegerRetVal =
         RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
 
-    for (const unsigned I : llvm::seq(VTs.size())) {
+    const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
+    unsigned I = 0;
+    for (const unsigned VectorizedSize : VectorInfo) {
       bool NeedTruncate = false;
       EVT TheLoadType = VTs[I];
       EVT EltType = Ins[I].VT;
-      Align EltAlign = commonAlignment(RetAlign, Offsets[I]);
+      const Align EltAlign = commonAlignment(RetAlign, Offsets[I]);
 
       if (auto PromotedVT = PromoteScalarIntegerPTX(TheLoadType)) {
         TheLoadType = *PromotedVT;
@@ -1895,16 +1870,15 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
         EltType = MVT::i32;
         NeedTruncate = true;
       } else if (TheLoadType.getSizeInBits() < 16) {
-        if (VTs[I].isInteger())
+        if (TheLoadType.isInteger())
           NeedTruncate = true;
         EltType = MVT::i16;
       }
 
       // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
       // scalar load. In such cases, fall back to byte loads.
-      if (VectorInfo[I] == PVF_SCALAR && RetTy->isAggregateType() &&
+      if (VectorizedSize == 1 && RetTy->isAggregateType() &&
           EltAlign < DAG.getEVTAlign(TheLoadType)) {
-        assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
         SDValue Ret = LowerUnalignedLoadRetParam(
             DAG, Chain, Offsets[I], TheLoadType, InGlue, TempProxyRegOps, dl);
         ProxyRegOps.push_back(SDValue());
@@ -1912,58 +1886,46 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
         RetElts.resize(I);
         RetElts.push_back(Ret);
 
+        I++;
         continue;
       }
 
-      // Record index of the very first element of the vector.
-      if (VectorInfo[I] & PVF_FIRST) {
-        assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
-        VecIdx = I;
-      }
-
-      LoadVTs.push_back(EltType);
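+      // Result types of the load: one entry per element, plus the chain and
+      // glue outputs.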
+      SmallVector<EVT, 6> LoadVTs(VectorizedSize, EltType);
+      LoadVTs.append({MVT::Other, MVT::Glue});
 
-      if (VectorInfo[I] & PVF_LAST) {
-        const unsigned NumElts = LoadVTs.size();
-        LoadVTs.append({MVT::Other, MVT::Glue});
-        NVPTXISD::NodeType Op;
-        switch (NumElts) {
-        case 1:
-          Op = NVPTXISD::LoadParam;
-          break;
-        case 2:
-          Op = NVPTXISD::LoadParamV2;
-          break;
-        case 4:
-          Op = NVPTXISD::LoadParamV4;
-          break;
-        default:
-          llvm_unreachable("Invalid vector info.");
-        }
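+      // Pick the LoadParam variant that matches the vectorization width.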
+      NVPTXISD::NodeType Op;
+      switch (VectorizedSize) {
+      case 1:
+        Op = NVPTXISD::LoadParam;
+        break;
+      case 2:
+        Op = NVPTXISD::LoadParamV2;
+        break;
+      case 4:
+        Op = NVPTXISD::LoadParamV4;
+        break;
+      default:
+        llvm_unreachable("Invalid vector info.");
+      }
 
-        SDValue LoadOperands[] = {Chain, GetI32(1), GetI32(Offsets[VecIdx]),
-                                  InGlue};
-        SDValue RetVal = DAG.getMemIntrinsicNode(
-            Op, dl, DAG.getVTList(LoadVTs), LoadOperands, TheLoadType,
-            MachinePointerInfo(), EltAlign,
-            MachineMemOperand::MOLoad);
+      SDValue LoadOperands[] = {Chain, GetI32(1), GetI32(Offsets[I]), InGlue};
+      SDValue RetVal = DAG.getMemIntrinsicNode(
+          Op, dl, DAG.getVTList(LoadVTs), LoadOperands, TheLoadType,
+          MachinePointerInfo(), EltAlign, MachineMemOperand::MOLoad);
 
-        for (const unsigned J : llvm::seq(NumElts)) {
-          ProxyRegOps.push_back(RetVal.getValue(J));
+      for (const unsigned J : llvm::seq(VectorizedSize)) {
+        ProxyRegOps.push_back(RetVal.getValue(J));
 
-          if (NeedTruncate)
-            ProxyRegTruncates.push_back(Ins[VecIdx + J].VT);
-          else
-            ProxyRegTruncates.push_back(std::nullopt);
-        }
+        if (NeedTruncate)
+          ProxyRegTruncates.push_back(Ins[I + J].VT);
+        else
+          ProxyRegTruncates.push_back(std::nullopt);
+      }
 
-        Chain = RetVal.getValue(NumElts);
-        InGlue = RetVal.getValue(NumElts + 1);
+      Chain = RetVal.getValue(VectorizedSize);
+      InGlue = RetVal.getValue(VectorizedSize + 1);
 
-        // Cleanup
-        VecIdx = -1;
-        LoadVTs.clear();
-      }
+      I += VectorizedSize;
     }
   }
 
@@ -3409,77 +3371,65 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
       assert(VTs.size() == ArgIns.size() && "Size mismatch");
       assert(VTs.size() == Offsets.size() && "Size mismatch");
 
-      Align ArgAlign = getFunctionArgumentAlignment(
+      const Align ArgAlign = getFunctionArgumentAlignment(
           F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
-      auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
-      assert(VectorInfo.size() == VTs.size() && "Size mismatch");
-
-      int VecIdx = -1; // Index of the first element of the current vector.
-      for (const unsigned PartI : llvm::seq(VTs.size())) {
-        if (VectorInfo[PartI] & PVF_FIRST) {
-          assert(VecIdx == -1 && "Orphaned vector.");
-          VecIdx = PartI;
-        }
 
-        // That's the last element of this store op.
-        if (VectorInfo[PartI] & PVF_LAST) {
-          const unsigned NumElts = PartI - VecIdx + 1;
-          const EVT EltVT = VTs[PartI];
-          const EVT LoadVT = [&]() -> EVT {
-            // i1 is loaded/stored as i8.
-            if (EltVT == MVT::i1)
-              return MVT::i8;
-            // getLoad needs a vector type, but it can't handle
-            // vectors which contain v2f16 or v2bf16 elements. So we must load
-            // using i32 here and then bitcast back.
-            if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
-              return MVT::i32;
-            return EltVT;
-          }();
-
-          const EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
-          SDValue VecAddr = DAG.getObjectPtrOffset(
-              dl, ArgSymbol, TypeSize::getFixed(Offsets[VecIdx]));
-
-          const MaybeAlign PartAlign = [&]() -> MaybeAlign {
-            if (aggregateIsPacked)
-              return Align(1);
-            if (NumElts != 1)
-              return std::nullopt;
-            Align PartAlign = DAG.getEVTAlign(EltVT);
-            return commonAlignment(PartAlign, Offsets[PartI]);
-          }();
-          SDValue P =
-              DAG.getLoad(VecVT, dl, Root, VecAddr,
-                          MachinePointerInfo(ADDRESS_SPACE_PARAM), PartAlign,
-                          MachineMemOperand::MODereferenceable |
-                              MachineMemOperand::MOInvariant);
-          if (P.getNode())
-            P.getNode()->setIROrder(Arg.getArgNo() + 1);
-          for (const unsigned J : llvm::seq(NumElts)) {
-            SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
-                                      DAG.getIntPtrConstant(J, dl));
-
-            // Extend or truncate the element if necessary (e.g. an i8 is loaded
-            // into an i16 register)
-            const EVT ExpactedVT = ArgIns[PartI].VT;
-            if (ExpactedVT.getFixedSizeInBits() !=
-                Elt.getValueType().getFixedSizeInBits()) {
-              assert(ExpactedVT.isScalarInteger() &&
-                     Elt.getValueType().isScalarInteger() &&
-                     "Non-integer argument type size mismatch");
-              Elt = DAG.getExtOrTrunc(ArgIns[PartI].Flags.isSExt(), Elt, dl,
-                                      ExpactedVT);
-            } else {
-              // v2f16 was loaded as an i32. Now we must bitcast it back.
-              Elt = DAG.getBitcast(EltVT, Elt);
-            }
-            InVals.push_back(Elt);
+      const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
+      unsigned I = 0;
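+      // I indexes the first part covered by the current chunk; each
+      // VectorInfo entry gives the number of parts loaded together.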
+      for (const unsigned NumElts : VectorInfo) {
+        const EVT EltVT = VTs[I];
+        const EVT LoadVT = [&]() -> EVT {
+          // i1 is loaded/stored as i8.
+          if (EltVT == MVT::i1)
+            return MVT::i8;
+          // getLoad needs a vector type, but it can't handle
+          // vectors which contain v2f16 or v2bf16 elements. So we must load
+          // using i32 here and then bitcast back.
+          if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
+            return MVT::i32;
+          return EltVT;
+        }();
+
+        const EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
+        SDValue VecAddr = DAG.getObjectPtrOffset(
+            dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
+
+        const MaybeAlign PartAlign = [&]() -> MaybeAlign {
+          if (aggregateIsPacked)
+            return Align(1);
+          if (NumElts != 1)
+            return std::nullopt;
+          Align PartAlign = DAG.getEVTAlign(EltVT);
+          return commonAlignment(PartAlign, Offsets[I]);
+        }();
+        SDValue P =
+            DAG.getLoad(VecVT, dl, Root, VecAddr,
+                        MachinePointerInfo(ADDRESS_SPACE_PARAM), PartAlign,
+                        MachineMemOperand::MODereferenceable |
+                            MachineMemOperand::MOInvariant);
+        if (P.getNode())
+          P.getNode()->setIROrder(Arg.getArgNo() + 1);
+        for (const unsigned J : llvm::seq(NumElts)) {
+          SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
+                                    DAG.getIntPtrConstant(J, dl));
+
+          // Extend or truncate the element if necessary (e.g. an i8 is loaded
+          // into an i16 register)
+          const EVT ExpactedVT = ArgIns[I + J].VT;
+          if (ExpactedVT.getFixedSizeInBits() !=
+              Elt.getValueType().getFixedSizeInBits()) {
+            assert(ExpactedVT.isScalarInteger() &&
+                   Elt.getValueType().isScalarInteger() &&
+                   "Non-integer argument type size mismatch");
+            Elt = DAG.getExtOrTrunc(ArgIns[I + J].Flags.isSExt(), Elt, dl,
+                                    ExpactedVT);
+          } else {
+            // v2f16 was loaded as an i32. Now we must bitcast it back.
+            Elt = DAG.getBitcast(EltVT, Elt);
           }
-
-          // Reset vector tracking state.
-          VecIdx = -1;
+          InVals.push_back(Elt);
         }
+        I += NumElts;
       }
     }
   }
@@ -3527,6 +3477,11 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
   const Function &F = MF.getFunction();
   Type *RetTy = MF.getFunction().getReturnType();
 
+  if (RetTy->isVoidTy()) {
+    assert(OutVals.empty() && Outs.empty() &&
+           "No return values expected for a void return type");
+    return DAG.getNode(NVPTXISD::RET_GLUE, dl, MVT::Other, Chain);
+  }
+
   const DataLayout &DL = DAG.getDataLayout();
   SmallVector<EVT, 16> VTs;
   SmallVector<uint64_t, 16> Offsets;
@@ -3534,22 +3489,16 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
   assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
 
   for (const unsigned I : llvm::seq(VTs.size()))
-    if (auto PromotedVT = PromoteScalarIntegerPTX(VTs[I]))
+    if (const auto PromotedVT = PromoteScalarIntegerPTX(VTs[I]))
       VTs[I] = *PromotedVT;
 
-  auto VectorInfo = VectorizePTXValueVTs(
-      VTs, Offsets,
-      !RetTy->isVoidTy() ? getFunctionParamOptimizedAlign(&F, RetTy, DL)
-                       : Align(1));
-
   // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
   // 32-bits are sign extended or zero extended, depending on whether
   // they are signed or unsigned types.
   const bool ExtendIntegerRetVal =
       RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
 
-  SmallVector<SDValue, 6> StoreOperands;
-  for (const unsigned I : llvm::seq(VTs.size())) {
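+  // Produce the I-th return value, applying any integer promotion or
+  // extension the ABI requires.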
+  const auto GetRetVal = [&](unsigned I) -> SDValue {
     SDValue RetVal = OutVals[I];
     assert(!PromoteScalarIntegerPTX(RetVal.getValueType()) &&
            "OutVal type should always be legal");
@@ -3563,61 +3512,58 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
       // smallest general purpose register size supported by NVPTX.
       RetVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, RetVal);
     }
+    return RetVal;
+  };
 
-    // If we have a PVF_SCALAR entry, it may not even be sufficiently aligned
-    // for a scalar store. In such cases, fall back to byte stores.
-    if (VectorInfo[I] == PVF_SCALAR && RetTy->isAggregateType()) {
-      const EVT ElementType = ExtendIntegerRetVal ? MVT::i32 : VTs[I];
+  const auto VectorInfo = VectorizePTXValueVTs(
+      VTs, Offsets, getFunctionParamOptimizedAlign(&F, RetTy, DL));
+  unsigned I = 0;
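+  // Emit one StoreRetval{,V2,V4} per vectorized chunk of the return value.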
+  for (const unsigned NumElts : VectorInfo) {
+    if (NumElts == 1 && RetTy->isAggregateType()) {
+      const EVT ElementType = VTs[I];
       const Align ElementTypeAlign = DAG.getEVTAlign(ElementType);
       const Align ElementAlign =
           commonAlignment(DL.getABITypeAlign(RetTy), Offsets[I]);
       if (ElementAlign < ElementTypeAlign) {
-        assert(StoreOperands.empty() && "Orphaned operand list.");
         Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[I], ElementType,
-                                       RetVal, dl);
+                                       GetRetVal(I), dl);
 
         // The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes
         // into the graph, so just move on to the next element.
+        I++;
         continue;
       }
     }
 
-    // New load/store. Record chain and offset operands.
-    if (VectorInfo[I] & PVF_FIRST) {
-      assert(StoreOperands.empty() && "Orphaned operand list.");
-      StoreOperands.append({Chain, DAG.getConstant(Offsets[I], dl, MVT::i32)});
-    }
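+    // Operands: chain, byte offset of this chunk, then the values to store.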
+    SmallVector<SDValue, 6> StoreOperands{
+        Chain, DAG.getConstant(Offsets[I], dl, MVT::i32)};
 
-    // Record the value to return.
-    StoreOperands.push_back(RetVal);
+    for (const unsigned J : llvm::seq(NumElts))
+      StoreOperands.push_back(GetRetVal(I + J));
 
-    // That's the last element of this store op.
-    if (VectorInfo[I] & PVF_LAST) {
-      NVPTXISD::NodeType Op;
-      unsigned NumElts = StoreOperands.size() - 2;
-      switch (NumElts) {
-      case 1:
-        Op = NVPTXISD::StoreRetval;
-        break;
-      case 2:
-        Op = NVPTXISD::StoreRetvalV2;
-        break;
-      case 4:
-        Op = NVPTXISD::StoreRetvalV4;
-        break;
-      default:
-        llvm_unreachable("Invalid vector info.");
-      }
-
-      // Adjust type of load/store op if we've extended the scalar
-      // return value.
-      EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[I];
-      Chain = DAG.getMemIntrinsicNode(
-          Op, dl, DAG.getVTList(MVT::Other), StoreOperands, TheStoreType,
-          MachinePointerInfo(), Align(1), MachineMemOperand::MOStore);
-      // Cleanup vector state.
-      StoreOperands.clear();
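+    // Pick the StoreRetval variant for the number of elements being stored.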
+    NVPTXISD::NodeType Op;
+    switch (NumElts) {
+    case 1:
+      Op = NVPTXISD::StoreRetval;
+      break;
+    case 2:
+      Op = NVPTXISD::StoreRetvalV2;
+      break;
+    case 4:
+      Op = NVPTXISD::StoreRetvalV4;
+      break;
+    default:
+      llvm_unreachable("Invalid vector info.");
     }
+
+    // Adjust type of load/store op if we've extended the scalar
+    // return value.
+    EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[I];
+    Chain = DAG.getMemIntrinsicNode(
+        Op, dl, DAG.getVTList(MVT::Other), StoreOperands, TheStoreType,
+        MachinePointerInfo(), Align(1), MachineMemOperand::MOStore);
+
+    I += NumElts;
   }
 
   return DAG.getNode(NVPTXISD::RET_GLUE, dl, MVT::Other, Chain);

>From 70a1de76d2b8b7c69d3199d37b502301eaf65606 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Mon, 28 Apr 2025 16:16:39 +0000
Subject: [PATCH 3/3] [NVPTX][NFC] Final misc. cleanup

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 222 +++++++++-----------
 llvm/lib/Target/NVPTX/NVPTXISelLowering.h   |  10 +-
 2 files changed, 101 insertions(+), 131 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index b8c419bee53ae..b21635f7caf04 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1144,21 +1144,24 @@ NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
 
 std::string NVPTXTargetLowering::getPrototype(
     const DataLayout &DL, Type *retTy, const ArgListTy &Args,
-    const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment,
-    std::optional<std::pair<unsigned, const APInt &>> VAInfo,
-    const CallBase &CB, unsigned UniqueCallSite) const {
+    const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign RetAlign,
+    std::optional<std::pair<unsigned, unsigned>> VAInfo, const CallBase &CB,
+    unsigned UniqueCallSite) const {
   auto PtrVT = getPointerTy(DL);
 
   std::string Prototype;
   raw_string_ostream O(Prototype);
   O << "prototype_" << UniqueCallSite << " : .callprototype ";
 
-  if (retTy->getTypeID() == Type::VoidTyID) {
+  if (retTy->isVoidTy()) {
     O << "()";
   } else {
     O << "(";
-    if ((retTy->isFloatingPointTy() || retTy->isIntegerTy()) &&
-        !shouldPassAsArray(retTy)) {
+    if (shouldPassAsArray(retTy)) {
+      assert(RetAlign && "RetAlign must be set for non-void return types");
+      O << ".param .align " << RetAlign->value() << " .b8 _["
+        << DL.getTypeAllocSize(retTy) << "]";
+    } else if (retTy->isFloatingPointTy() || retTy->isIntegerTy()) {
       unsigned size = 0;
       if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
         size = ITy->getBitWidth();
@@ -1175,9 +1178,6 @@ std::string NVPTXTargetLowering::getPrototype(
       O << ".param .b" << size << " _";
     } else if (isa<PointerType>(retTy)) {
       O << ".param .b" << PtrVT.getSizeInBits() << " _";
-    } else if (shouldPassAsArray(retTy)) {
-      O << ".param .align " << (retAlignment ? retAlignment->value() : 0)
-        << " .b8 _[" << DL.getTypeAllocSize(retTy) << "]";
     } else {
       llvm_unreachable("Unknown return type");
     }
@@ -1187,57 +1187,52 @@ std::string NVPTXTargetLowering::getPrototype(
 
   bool first = true;
 
-  unsigned NumArgs = VAInfo ? VAInfo->first : Args.size();
-  for (unsigned i = 0, OIdx = 0; i != NumArgs; ++i, ++OIdx) {
-    Type *Ty = Args[i].Ty;
+  const unsigned NumArgs = VAInfo ? VAInfo->first : Args.size();
+  auto AllOuts = ArrayRef(Outs);
+  for (const unsigned I : llvm::seq(NumArgs)) {
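+    // Peel off the Outs entries that belong to this IR argument.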
+    const auto ArgOuts =
+        AllOuts.take_while([I](auto O) { return O.OrigArgIndex == I; });
+    AllOuts = AllOuts.drop_front(ArgOuts.size());
+
+    Type *Ty = Args[I].Ty;
     if (!first) {
       O << ", ";
     }
     first = false;
 
-    if (!Outs[OIdx].Flags.isByVal()) {
+    if (ArgOuts[0].Flags.isByVal()) {
+      // Indirect calls need strict ABI alignment so we disable optimizations by
+      // not providing a function to optimize.
+      Type *ETy = Args[I].IndirectType;
+      Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
+      Align ParamByValAlign =
+          getFunctionByValParamAlign(/*F=*/nullptr, ETy, InitialAlign, DL);
+
+      O << ".param .align " << ParamByValAlign.value() << " .b8 _["
+        << ArgOuts[0].Flags.getByValSize() << "]";
+    } else {
       if (shouldPassAsArray(Ty)) {
         Align ParamAlign =
-            getArgumentAlignment(&CB, Ty, i + AttributeList::FirstArgIndex, DL);
-        O << ".param .align " << ParamAlign.value() << " .b8 ";
-        O << "_";
-        O << "[" << DL.getTypeAllocSize(Ty) << "]";
-        // update the index for Outs
-        SmallVector<EVT, 16> vtparts;
-        ComputeValueVTs(*this, DL, Ty, vtparts);
-        if (unsigned len = vtparts.size())
-          OIdx += len - 1;
+            getArgumentAlignment(&CB, Ty, I + AttributeList::FirstArgIndex, DL);
+        O << ".param .align " << ParamAlign.value() << " .b8 _["
+          << DL.getTypeAllocSize(Ty) << "]";
         continue;
       }
       // i8 types in IR will be i16 types in SDAG
-      assert((getValueType(DL, Ty) == Outs[OIdx].VT ||
-              (getValueType(DL, Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) &&
+      assert((getValueType(DL, Ty) == ArgOuts[0].VT ||
+              (getValueType(DL, Ty) == MVT::i8 && ArgOuts[0].VT == MVT::i16)) &&
              "type mismatch between callee prototype and arguments");
       // scalar type
       unsigned sz = 0;
-      if (isa<IntegerType>(Ty)) {
-        sz = cast<IntegerType>(Ty)->getBitWidth();
-        sz = promoteScalarArgumentSize(sz);
+      if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
+        sz = promoteScalarArgumentSize(ITy->getBitWidth());
       } else if (isa<PointerType>(Ty)) {
         sz = PtrVT.getSizeInBits();
       } else {
         sz = Ty->getPrimitiveSizeInBits();
       }
-      O << ".param .b" << sz << " ";
-      O << "_";
-      continue;
+      O << ".param .b" << sz << " _";
     }
-
-    // Indirect calls need strict ABI alignment so we disable optimizations by
-    // not providing a function to optimize.
-    Type *ETy = Args[i].IndirectType;
-    Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
-    Align ParamByValAlign =
-        getFunctionByValParamAlign(/*F=*/nullptr, ETy, InitialAlign, DL);
-
-    O << ".param .align " << ParamByValAlign.value() << " .b8 ";
-    O << "_";
-    O << "[" << Outs[OIdx].Flags.getByValSize() << "]";
   }
 
   if (VAInfo)
@@ -1420,6 +1415,10 @@ static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
   return MachinePointerInfo();
 }
 
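+/// Return the extension opcode (sign- or zero-extend) implied by an
+/// argument's extension flags.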
+static ISD::NodeType getExtOpcode(const ISD::ArgFlagsTy &Flags) {
+  return Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+}
+
 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                                        SmallVectorImpl<SDValue> &InVals) const {
 
@@ -1483,14 +1482,14 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
          "Outs and OutVals must be the same size");
   // Declare the .params or .regs needed to pass values
   // to the function
-  for (const auto [I, Arg] : llvm::enumerate(Args)) {
-    const auto ArgOuts =
-        AllOuts.take_while([I = I](auto O) { return O.OrigArgIndex == I; });
+  for (const auto [ArgI, Arg] : llvm::enumerate(Args)) {
+    const auto ArgOuts = AllOuts.take_while(
+        [ArgI = ArgI](auto O) { return O.OrigArgIndex == ArgI; });
     const auto ArgOutVals = AllOutVals.take_front(ArgOuts.size());
     AllOuts = AllOuts.drop_front(ArgOuts.size());
     AllOutVals = AllOutVals.drop_front(ArgOuts.size());
 
-    const bool IsVAArg = (I >= FirstVAArg);
+    const bool IsVAArg = (ArgI >= FirstVAArg);
     const bool IsByVal = Arg.IsByVal;
 
     SmallVector<EVT, 16> VTs;
@@ -1514,29 +1513,26 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       if (IsVAArg)
         VAOffset = alignTo(VAOffset, ArgAlign);
     } else {
-      ArgAlign = getArgumentAlignment(CB, Arg.Ty, I + 1, DL);
+      ArgAlign = getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL);
     }
 
     const unsigned TypeSize = DL.getTypeAllocSize(ETy);
     assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) &&
            "type size mismatch");
 
-    bool NeedAlign; // Does argument declaration specify alignment?
     const bool PassAsArray = IsByVal || shouldPassAsArray(Arg.Ty);
     if (IsVAArg) {
-      if (I == FirstVAArg) {
+      if (ArgI == FirstVAArg) {
         VADeclareParam = Chain =
             DAG.getNode(NVPTXISD::DeclareParam, dl, {MVT::Other, MVT::Glue},
                         {Chain, GetI32(STI.getMaxRequiredAlignment()),
-                         GetI32(I), GetI32(1), InGlue});
+                         GetI32(ArgI), GetI32(1), InGlue});
       }
-      NeedAlign = PassAsArray;
     } else if (PassAsArray) {
       // declare .param .align <align> .b8 .param<n>[<size>];
       Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, {MVT::Other, MVT::Glue},
-                          {Chain, GetI32(ArgAlign.value()), GetI32(I),
+                          {Chain, GetI32(ArgAlign.value()), GetI32(ArgI),
                            GetI32(TypeSize), InGlue});
-      NeedAlign = true;
     } else {
       assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
       // declare .param .b<size> .param<n>;
@@ -1551,8 +1547,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
 
       Chain = DAG.getNode(
           NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
-          {Chain, GetI32(I), GetI32(PromotedSize), GetI32(0), InGlue});
-      NeedAlign = false;
+          {Chain, GetI32(ArgI), GetI32(PromotedSize), GetI32(0), InGlue});
     }
     InGlue = Chain.getValue(1);
 
@@ -1564,7 +1559,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
         Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
 
     const auto GetStoredValue = [&](const unsigned I, EVT EltVT,
-                                    MaybeAlign PartAlign) {
+                                    const Align PartAlign) {
       SDValue StVal;
       if (IsByVal) {
         SDValue Ptr = ArgOutVals[0];
@@ -1577,18 +1572,16 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
         StVal = ArgOutVals[I];
 
         if (auto PromotedVT = PromoteScalarIntegerPTX(StVal.getValueType())) {
-          llvm::ISD::NodeType Ext =
-              ArgOuts[I].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
-          StVal = DAG.getNode(Ext, dl, *PromotedVT, StVal);
+          StVal = DAG.getNode(getExtOpcode(ArgOuts[I].Flags), dl, *PromotedVT,
+                              StVal);
         }
       }
 
       if (ExtendIntegerParam) {
         assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
         // zext/sext to i32
-        StVal = DAG.getNode(ArgOuts[I].Flags.isSExt() ? ISD::SIGN_EXTEND
-                                                      : ISD::ZERO_EXTEND,
-                            dl, MVT::i32, StVal);
+        StVal =
+            DAG.getNode(getExtOpcode(ArgOuts[I].Flags), dl, MVT::i32, StVal);
       } else if (EltVT.getSizeInBits() < 16) {
         // Use 16-bit registers for small stores as it's the
         // smallest general purpose register size supported by NVPTX.
@@ -1604,22 +1597,19 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     for (const unsigned NumElts : VectorInfo) {
       const int CurOffset = Offsets[J];
       EVT EltVT = VTs[J];
-      MaybeAlign PartAlign;
-      if (NeedAlign)
-        PartAlign = commonAlignment(ArgAlign, CurOffset);
+      const Align PartAlign = commonAlignment(ArgAlign, CurOffset);
 
       if (auto PromotedVT = PromoteScalarIntegerPTX(EltVT))
         EltVT = *PromotedVT;
 
       // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
       // scalar store. In such cases, fall back to byte stores.
-      if (NumElts == 1 && !IsVAArg && PartAlign.has_value() &&
-          PartAlign.value() < DAG.getEVTAlign(EltVT)) {
+      if (NumElts == 1 && !IsVAArg && PartAlign < DAG.getEVTAlign(EltVT)) {
 
         SDValue StVal = GetStoredValue(J, EltVT, PartAlign);
         Chain = LowerUnalignedStoreParam(DAG, Chain,
                                          CurOffset + (IsByVal ? VAOffset : 0),
-                                         EltVT, StVal, InGlue, I, dl);
+                                         EltVT, StVal, InGlue, ArgI, dl);
 
         // LowerUnalignedStoreParam took care of inserting the necessary nodes
         // into the SDAG, so just move on to the next element.
@@ -1634,9 +1624,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       assert((IsVAArg || VAOffset == 0) &&
              "VAOffset must be 0 for non-VA args");
       SmallVector<SDValue, 6> StoreOperands{
-          Chain, GetI32(IsVAArg ? FirstVAArg : I),
-          GetI32(IsByVal ? CurOffset + VAOffset
-                         : (IsVAArg ? VAOffset : CurOffset))};
+          Chain, GetI32(IsVAArg ? FirstVAArg : ArgI),
+          GetI32(VAOffset + ((IsVAArg && !IsByVal) ? 0 : CurOffset))};
 
       // Record the values to store.
       for (const unsigned K : llvm::seq(NumElts))
@@ -1683,12 +1672,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   }
 
   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
-  MaybeAlign retAlignment = std::nullopt;
+  MaybeAlign RetAlign = std::nullopt;
 
   // Handle Result
   if (!Ins.empty()) {
-    SmallVector<EVT, 16> resvtparts;
-    ComputeValueVTs(*this, DL, RetTy, resvtparts);
+    RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
 
     // Declare
     //  .param .align N .b8 retval0[<size-in-bytes>], or
@@ -1702,9 +1690,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                           DeclareRetOps);
       InGlue = Chain.getValue(1);
     } else {
-      retAlignment = getArgumentAlignment(CB, RetTy, 0, DL);
-      assert(retAlignment && "retAlignment is guaranteed to be set");
-      SDValue DeclareRetOps[] = {Chain, GetI32(retAlignment->value()),
+      SDValue DeclareRetOps[] = {Chain, GetI32(RetAlign->value()),
                                  GetI32(ResultSize / 8), GetI32(0), InGlue};
       Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl,
                           {MVT::Other, MVT::Glue}, DeclareRetOps);
@@ -1754,10 +1740,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     // The prototype is embedded in a string and put as the operand for a
     // CallPrototype SDNode which will print out to the value of the string.
     std::string Proto = getPrototype(
-        DL, RetTy, Args, CLI.Outs, retAlignment,
+        DL, RetTy, Args, CLI.Outs, RetAlign,
         HasVAArgs
-            ? std::optional<std::pair<unsigned, const APInt &>>(std::make_pair(
-                  CLI.NumFixedArgs, VADeclareParam->getConstantOperandAPInt(1)))
+            ? std::optional<std::pair<unsigned, unsigned>>(std::make_pair(
+                  CLI.NumFixedArgs, VADeclareParam.getConstantOperandVal(1)))
             : std::nullopt,
         *CB, UniqueCallSite);
     const char *ProtoStr = nvTM->getStrPool().save(Proto).data();
@@ -1826,7 +1812,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   }
 
   SmallVector<SDValue, 16> ProxyRegOps;
-  SmallVector<std::optional<MVT>, 16> ProxyRegTruncates;
   // An item of the vector is filled if the element does not need a ProxyReg
-  // operation on it and should be added to InVals as is. ProxyRegOps and
-  // ProxyRegTruncates contain empty/none items at the same index.
+  // operation on it and should be added to InVals as is. ProxyRegOps
+  // contains empty items at the same index.
@@ -1843,7 +1828,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0);
     assert(VTs.size() == Ins.size() && "Bad value decomposition");
 
-    const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
+    assert(RetAlign && "RetAlign is guaranteed to be set");
 
     // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
     // 32-bits are sign extended or zero extended, depending on whether
@@ -1851,27 +1836,22 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     const bool ExtendIntegerRetVal =
         RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
 
-    const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
+    const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, *RetAlign);
     unsigned I = 0;
     for (const unsigned VectorizedSize : VectorInfo) {
-      bool NeedTruncate = false;
       EVT TheLoadType = VTs[I];
       EVT EltType = Ins[I].VT;
-      const Align EltAlign = commonAlignment(RetAlign, Offsets[I]);
+      const Align EltAlign = commonAlignment(*RetAlign, Offsets[I]);
 
       if (auto PromotedVT = PromoteScalarIntegerPTX(TheLoadType)) {
         TheLoadType = *PromotedVT;
         EltType = *PromotedVT;
-        NeedTruncate = true;
       }
 
       if (ExtendIntegerRetVal) {
         TheLoadType = MVT::i32;
         EltType = MVT::i32;
-        NeedTruncate = true;
       } else if (TheLoadType.getSizeInBits() < 16) {
-        if (TheLoadType.isInteger())
-          NeedTruncate = true;
         EltType = MVT::i16;
       }
 
@@ -1882,7 +1862,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
         SDValue Ret = LowerUnalignedLoadRetParam(
             DAG, Chain, Offsets[I], TheLoadType, InGlue, TempProxyRegOps, dl);
         ProxyRegOps.push_back(SDValue());
-        ProxyRegTruncates.push_back(std::optional<MVT>());
         RetElts.resize(I);
         RetElts.push_back(Ret);
 
@@ -1915,11 +1894,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
 
       for (const unsigned J : llvm::seq(VectorizedSize)) {
         ProxyRegOps.push_back(RetVal.getValue(J));
-
-        if (NeedTruncate)
-          ProxyRegTruncates.push_back(Ins[I + J].VT);
-        else
-          ProxyRegTruncates.push_back(std::nullopt);
       }
 
       Chain = RetVal.getValue(VectorizedSize);
@@ -1950,10 +1924,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     Chain = Ret.getValue(1);
     InGlue = Ret.getValue(2);
 
-    if (ProxyRegTruncates[I]) {
-      Ret = DAG.getNode(ISD::TRUNCATE, dl, *ProxyRegTruncates[I], Ret);
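+    // If promotion or extension widened the proxied value, truncate it back
+    // to the type the caller expects.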
+    const EVT ExpectedVT = Ins[I].VT;
+    if (!Ret.getValueType().bitsEq(ExpectedVT)) {
+      Ret = DAG.getNode(ISD::TRUNCATE, dl, ExpectedVT, Ret);
     }
-
     InVals.push_back(Ret);
   }
 
@@ -3385,8 +3359,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
           // getLoad needs a vector type, but it can't handle
           // vectors which contain v2f16 or v2bf16 elements. So we must load
           // using i32 here and then bitcast back.
-          if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
-            return MVT::i32;
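+          // Use a scalar integer of the same total width for any small
+          // vector element type; it is bitcast back after the load.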
+          if (EltVT.isVector())
+            return MVT::getIntegerVT(EltVT.getFixedSizeInBits());
           return EltVT;
         }();
 
@@ -3416,13 +3390,15 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
           // Extend or truncate the element if necessary (e.g. an i8 is loaded
           // into an i16 register)
           const EVT ExpactedVT = ArgIns[I + J].VT;
-          if (ExpactedVT.getFixedSizeInBits() !=
-              Elt.getValueType().getFixedSizeInBits()) {
-            assert(ExpactedVT.isScalarInteger() &&
-                   Elt.getValueType().isScalarInteger() &&
-                   "Non-integer argument type size mismatch");
-            Elt = DAG.getExtOrTrunc(ArgIns[I + J].Flags.isSExt(), Elt, dl,
-                                    ExpactedVT);
+          assert((Elt.getValueType().bitsEq(ExpactedVT) ||
+                  (ExpactedVT.isScalarInteger() &&
+                   Elt.getValueType().isScalarInteger())) &&
+                 "Non-integer argument type size mismatch");
+          if (ExpactedVT.bitsGT(Elt.getValueType())) {
+            Elt = DAG.getNode(getExtOpcode(ArgIns[I + J].Flags), dl, ExpactedVT,
+                              Elt);
+          } else if (ExpactedVT.bitsLT(Elt.getValueType())) {
+            Elt = DAG.getNode(ISD::TRUNCATE, dl, ExpactedVT, Elt);
           } else {
             // v2f16 was loaded as an i32. Now we must bitcast it back.
             Elt = DAG.getBitcast(EltVT, Elt);
@@ -3504,9 +3480,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
            "OutVal type should always be legal");
 
     if (ExtendIntegerRetVal) {
-      RetVal = DAG.getNode(Outs[I].Flags.isSExt() ? ISD::SIGN_EXTEND
-                                                  : ISD::ZERO_EXTEND,
-                           dl, MVT::i32, RetVal);
+      RetVal = DAG.getNode(getExtOpcode(Outs[I].Flags), dl, MVT::i32, RetVal);
     } else if (RetVal.getValueSizeInBits() < 16) {
       // Use 16-bit registers for small load-stores as it's the
       // smallest general purpose register size supported by NVPTX.
@@ -3515,24 +3489,20 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
     return RetVal;
   };
 
-  const auto VectorInfo = VectorizePTXValueVTs(
-      VTs, Offsets, getFunctionParamOptimizedAlign(&F, RetTy, DL));
+  const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL);
+  const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
   unsigned I = 0;
   for (const unsigned NumElts : VectorInfo) {
-    if (NumElts == 1 && RetTy->isAggregateType()) {
-      const EVT ElementType = VTs[I];
-      const Align ElementTypeAlign = DAG.getEVTAlign(ElementType);
-      const Align ElementAlign =
-          commonAlignment(DL.getABITypeAlign(RetTy), Offsets[I]);
-      if (ElementAlign < ElementTypeAlign) {
-        Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[I], ElementType,
-                                       GetRetVal(I), dl);
-
-        // The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes
-        // into the graph, so just move on to the next element.
-        I++;
-        continue;
-      }
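+    // A scalar piece of an aggregate return may not be sufficiently aligned
+    // for a scalar store; fall back to byte stores in that case.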
+    const Align CurrentAlign = commonAlignment(RetAlign, Offsets[I]);
+    if (NumElts == 1 && RetTy->isAggregateType() &&
+        CurrentAlign < DAG.getEVTAlign(VTs[I])) {
+      Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[I], VTs[I],
+                                     GetRetVal(I), dl);
+
+      // The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes
+      // into the graph, so just move on to the next element.
+      I++;
+      continue;
     }
 
     SmallVector<SDValue, 6> StoreOperands{
@@ -3561,7 +3531,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
     EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[I];
     Chain = DAG.getMemIntrinsicNode(
         Op, dl, DAG.getVTList(MVT::Other), StoreOperands, TheStoreType,
-        MachinePointerInfo(), Align(1), MachineMemOperand::MOStore);
+        MachinePointerInfo(), CurrentAlign, MachineMemOperand::MOStore);
 
     I += NumElts;
   }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 7a8bf3bf33a94..3279a4c2e74f3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -187,11 +187,11 @@ class NVPTXTargetLowering : public TargetLowering {
   SDValue LowerSTACKSAVE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerSTACKRESTORE(SDValue Op, SelectionDAG &DAG) const;
 
-  std::string
-  getPrototype(const DataLayout &DL, Type *, const ArgListTy &,
-               const SmallVectorImpl<ISD::OutputArg> &, MaybeAlign retAlignment,
-               std::optional<std::pair<unsigned, const APInt &>> VAInfo,
-               const CallBase &CB, unsigned UniqueCallSite) const;
+  std::string getPrototype(const DataLayout &DL, Type *, const ArgListTy &,
+                           const SmallVectorImpl<ISD::OutputArg> &,
+                           MaybeAlign RetAlign,
+                           std::optional<std::pair<unsigned, unsigned>> VAInfo,
+                           const CallBase &CB, unsigned UniqueCallSite) const;
 
   SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
                       const SmallVectorImpl<ISD::OutputArg> &Outs,
