[llvm] 66e8163 - [NVPTX] Vectorize loads when lowering byval calls, misc. cleanup (#151070)

via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 1 15:02:06 PDT 2025


Author: Alex MacLean
Date: 2025-08-01T15:02:02-07:00
New Revision: 66e8163f53cacc704aab9d4c81f208727e37d3d0

URL: https://github.com/llvm/llvm-project/commit/66e8163f53cacc704aab9d4c81f208727e37d3d0
DIFF: https://github.com/llvm/llvm-project/commit/66e8163f53cacc704aab9d4c81f208727e37d3d0.diff

LOG: [NVPTX] Vectorize loads when lowering byval calls, misc. cleanup (#151070)

This change rewrites the LowerCall handling of byval arguments so that the
loads are vectorized in addition to the stores. It also includes various
minor NFC updates and cleanups to reduce code duplication.
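
To illustrate the effect (the "before" sequence below is only an approximate
sketch with illustrative register numbers, not taken from an actual before/after
comparison): for a byval aggregate such as the %struct.double2 in the new
byval-arg-vectorize.ll test, the old lowering emitted scalar loads from the
source object and only vectorized the parameter store, roughly

    ld.param.b64    %rd2, [call_byval_param_2];
    ld.b64          %rd3, [%rd2];       // scalar load of the first double
    ld.b64          %rd4, [%rd2+8];     // scalar load of the second double
    st.param.v2.b64 [param1], {%rd3, %rd4};

whereas with this change the source loads are merged as well, matching the
CHECK lines in the added test:

    ld.param.b64    %rd2, [call_byval_param_2];
    ld.v2.b64       {%rd3, %rd4}, [%rd2];    // one vectorized 128-bit load
    st.param.v2.b64 [param1], {%rd3, %rd4};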

Added: 
    llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll

Modified: 
    llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
    llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll
    llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
    llvm/test/CodeGen/NVPTX/param-vectorize-device.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 65d1be3a3847d..15f45a1f35e2f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -382,6 +382,54 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
   }
 }
 
+// We return an EVT that can hold N VTs
+// If the VT is a vector, the resulting EVT is a flat vector with the same
+// element type as VT's element type.
+static EVT getVectorizedVT(EVT VT, unsigned N, LLVMContext &C) {
+  if (N == 1)
+    return VT;
+
+  return VT.isVector() ? EVT::getVectorVT(C, VT.getScalarType(),
+                                          VT.getVectorNumElements() * N)
+                       : EVT::getVectorVT(C, VT, N);
+}
+
+static SDValue getExtractVectorizedValue(SDValue V, unsigned I, EVT VT,
+                                         const SDLoc &dl, SelectionDAG &DAG) {
+  if (V.getValueType() == VT) {
+    assert(I == 0 && "Index must be 0 for scalar value");
+    return V;
+  }
+
+  if (!VT.isVector())
+    return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, V,
+                       DAG.getVectorIdxConstant(I, dl));
+
+  return DAG.getNode(
+      ISD::EXTRACT_SUBVECTOR, dl, VT, V,
+      DAG.getVectorIdxConstant(I * VT.getVectorNumElements(), dl));
+}
+
+template <typename T>
+static inline SDValue getBuildVectorizedValue(unsigned N, const SDLoc &dl,
+                                              SelectionDAG &DAG, T GetElement) {
+  if (N == 1)
+    return GetElement(0);
+
+  SmallVector<SDValue, 8> Values;
+  for (const unsigned I : llvm::seq(N)) {
+    SDValue Val = GetElement(I);
+    if (Val.getValueType().isVector())
+      DAG.ExtractVectorElements(Val, Values);
+    else
+      Values.push_back(Val);
+  }
+
+  EVT VT = EVT::getVectorVT(*DAG.getContext(), Values[0].getValueType(),
+                            Values.size());
+  return DAG.getBuildVector(VT, dl, Values);
+}
+
 /// PromoteScalarIntegerPTX
 /// Used to make sure the arguments/returns are suitable for passing
 /// and promote them to a larger size if they're not.
@@ -420,9 +468,10 @@ static EVT promoteScalarIntegerPTX(const EVT VT) {
 // parameter starting at index Idx using a single vectorized op of
 // size AccessSize. If so, it returns the number of param pieces
 // covered by the vector op. Otherwise, it returns 1.
-static unsigned CanMergeParamLoadStoresStartingAt(
+template <typename T>
+static unsigned canMergeParamLoadStoresStartingAt(
     unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs,
-    const SmallVectorImpl<uint64_t> &Offsets, Align ParamAlignment) {
+    const SmallVectorImpl<T> &Offsets, Align ParamAlignment) {
 
   // Can't vectorize if param alignment is not sufficient.
   if (ParamAlignment < AccessSize)
@@ -472,10 +521,11 @@ static unsigned CanMergeParamLoadStoresStartingAt(
 // of the same size as ValueVTs indicating how each piece should be
 // loaded/stored (i.e. as a scalar, or as part of a vector
 // load/store).
+template <typename T>
 static SmallVector<unsigned, 16>
 VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
-                     const SmallVectorImpl<uint64_t> &Offsets,
-                     Align ParamAlignment, bool IsVAArg = false) {
+                     const SmallVectorImpl<T> &Offsets, Align ParamAlignment,
+                     bool IsVAArg = false) {
   // Set vector size to match ValueVTs and mark all elements as
   // scalars by default.
 
@@ -486,7 +536,7 @@ VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
 
   const auto GetNumElts = [&](unsigned I) -> unsigned {
     for (const unsigned AccessSize : {16, 8, 4, 2}) {
-      const unsigned NumElts = CanMergeParamLoadStoresStartingAt(
+      const unsigned NumElts = canMergeParamLoadStoresStartingAt(
           I, AccessSize, ValueVTs, Offsets, ParamAlignment);
       assert((NumElts == 1 || NumElts == 2 || NumElts == 4) &&
              "Unexpected vectorization size");
@@ -1384,6 +1434,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   Type *RetTy = CLI.RetTy;
   const CallBase *CB = CLI.CB;
   const DataLayout &DL = DAG.getDataLayout();
+  LLVMContext &Ctx = *DAG.getContext();
 
   const auto GetI32 = [&](const unsigned I) {
     return DAG.getConstant(I, dl, MVT::i32);
@@ -1476,15 +1527,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     const SDValue ParamSymbol =
         getCallParamSymbol(DAG, IsVAArg ? FirstVAArg : ArgI, MVT::i32);
 
-    SmallVector<EVT, 16> VTs;
-    SmallVector<uint64_t, 16> Offsets;
-
     assert((!IsByVal || Arg.IndirectType) &&
            "byval arg must have indirect type");
     Type *ETy = (IsByVal ? Arg.IndirectType : Arg.Ty);
-    ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, IsByVal ? 0 : VAOffset);
-    assert(VTs.size() == Offsets.size() && "Size mismatch");
-    assert((IsByVal || VTs.size() == ArgOuts.size()) && "Size mismatch");
 
     const Align ArgAlign = [&]() {
       if (IsByVal) {
@@ -1492,17 +1537,14 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
         // so we don't need to worry whether it's naturally aligned or not.
         // See TargetLowering::LowerCallTo().
         const Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
-        const Align ByValAlign = getFunctionByValParamAlign(
-            CB->getCalledFunction(), ETy, InitialAlign, DL);
-        if (IsVAArg)
-          VAOffset = alignTo(VAOffset, ByValAlign);
-        return ByValAlign;
+        return getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
+                                          InitialAlign, DL);
       }
       return getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL);
     }();
 
-    const unsigned TypeSize = DL.getTypeAllocSize(ETy);
-    assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) &&
+    const unsigned TySize = DL.getTypeAllocSize(ETy);
+    assert((!IsByVal || TySize == ArgOuts[0].Flags.getByValSize()) &&
            "type size mismatch");
 
     const SDValue ArgDeclare = [&]() {
@@ -1510,105 +1552,120 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
         return VADeclareParam;
 
       if (IsByVal || shouldPassAsArray(Arg.Ty))
-        return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TypeSize);
+        return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TySize);
 
       assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
       assert((ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint()) &&
              "Only int and float types are supported as non-array arguments");
 
-      return MakeDeclareScalarParam(ParamSymbol, TypeSize);
+      return MakeDeclareScalarParam(ParamSymbol, TySize);
     }();
 
-    // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
-    // than 32-bits are sign extended or zero extended, depending on
-    // whether they are signed or unsigned types. This case applies
-    // only to scalar parameters and not to aggregate values.
-    const bool ExtendIntegerParam =
-        Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
+    if (IsByVal) {
+      assert(ArgOutVals.size() == 1 && "We must pass only one value as byval");
+      SDValue SrcPtr = ArgOutVals[0];
+      const auto PointerInfo = refinePtrAS(SrcPtr, DAG, DL, *this);
+      const Align BaseSrcAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
 
-    const auto GetStoredValue = [&](const unsigned I, EVT EltVT,
-                                    const MaybeAlign PartAlign) {
-      if (IsByVal) {
-        SDValue Ptr = ArgOutVals[0];
-        auto MPI = refinePtrAS(Ptr, DAG, DL, *this);
-        SDValue SrcAddr =
-            DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(Offsets[I]));
-
-        return DAG.getLoad(EltVT, dl, CallChain, SrcAddr, MPI, PartAlign);
+      if (IsVAArg)
+        VAOffset = alignTo(VAOffset, ArgAlign);
+
+      SmallVector<EVT, 4> ValueVTs, MemVTs;
+      SmallVector<TypeSize, 4> Offsets;
+      ComputeValueVTs(*this, DL, ETy, ValueVTs, &MemVTs, &Offsets);
+
+      unsigned J = 0;
+      const auto VI = VectorizePTXValueVTs(MemVTs, Offsets, ArgAlign, IsVAArg);
+      for (const unsigned NumElts : VI) {
+        EVT LoadVT = getVectorizedVT(MemVTs[J], NumElts, Ctx);
+        Align SrcAlign = commonAlignment(BaseSrcAlign, Offsets[J]);
+        SDValue SrcAddr = DAG.getObjectPtrOffset(dl, SrcPtr, Offsets[J]);
+        SDValue SrcLoad =
+            DAG.getLoad(LoadVT, dl, CallChain, SrcAddr, PointerInfo, SrcAlign);
+
+        TypeSize ParamOffset = Offsets[J].getWithIncrement(VAOffset);
+        Align ParamAlign = commonAlignment(ArgAlign, ParamOffset);
+        SDValue ParamAddr =
+            DAG.getObjectPtrOffset(dl, ParamSymbol, ParamOffset);
+        SDValue StoreParam =
+            DAG.getStore(ArgDeclare, dl, SrcLoad, ParamAddr,
+                         MachinePointerInfo(ADDRESS_SPACE_PARAM), ParamAlign);
+        CallPrereqs.push_back(StoreParam);
+
+        J += NumElts;
       }
-      SDValue StVal = ArgOutVals[I];
-      assert(promoteScalarIntegerPTX(StVal.getValueType()) ==
-                 StVal.getValueType() &&
-             "OutVal type should always be legal");
-
-      const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
-      const EVT StoreVT =
-          ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
-
-      return correctParamType(StVal, StoreVT, ArgOuts[I].Flags, DAG, dl);
-    };
-
-    const auto VectorInfo =
-        VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
-
-    unsigned J = 0;
-    for (const unsigned NumElts : VectorInfo) {
-      const int CurOffset = Offsets[J];
-      const EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
-
-      if (IsVAArg && !IsByVal)
-        // Align each part of the variadic argument to their type.
-        VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
-
-      assert((IsVAArg || VAOffset == 0) &&
-             "VAOffset must be 0 for non-VA args");
+      if (IsVAArg)
+        VAOffset += TySize;
+    } else {
+      SmallVector<EVT, 16> VTs;
+      SmallVector<uint64_t, 16> Offsets;
+      ComputePTXValueVTs(*this, DL, Arg.Ty, VTs, &Offsets, VAOffset);
+      assert(VTs.size() == Offsets.size() && "Size mismatch");
+      assert(VTs.size() == ArgOuts.size() && "Size mismatch");
+
+      // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
+      // than 32-bits are sign extended or zero extended, depending on
+      // whether they are signed or unsigned types. This case applies
+      // only to scalar parameters and not to aggregate values.
+      const bool ExtendIntegerParam =
+          Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
+
+      const auto GetStoredValue = [&](const unsigned I) {
+        SDValue StVal = ArgOutVals[I];
+        assert(promoteScalarIntegerPTX(StVal.getValueType()) ==
+                   StVal.getValueType() &&
+               "OutVal type should always be legal");
+
+        const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
+        const EVT StoreVT =
+            ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
+
+        return correctParamType(StVal, StoreVT, ArgOuts[I].Flags, DAG, dl);
+      };
+
+      unsigned J = 0;
+      const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
+      for (const unsigned NumElts : VI) {
+        const EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
+
+        unsigned Offset;
+        if (IsVAArg) {
+          // TODO: We may need to support vector types that can be passed
+          // as scalars in variadic arguments.
+          assert(NumElts == 1 &&
+                 "Vectorization should be disabled for vaargs.");
+
+          // Align each part of the variadic argument to their type.
+          VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
+          Offset = VAOffset;
+
+          const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
+          VAOffset += DL.getTypeAllocSize(TheStoreType.getTypeForEVT(Ctx));
+        } else {
+          assert(VAOffset == 0 && "VAOffset must be 0 for non-VA args");
+          Offset = Offsets[J];
+        }
 
-      const unsigned Offset =
-          (VAOffset + ((IsVAArg && !IsByVal) ? 0 : CurOffset));
-      SDValue Ptr =
-          DAG.getObjectPtrOffset(dl, ParamSymbol, TypeSize::getFixed(Offset));
+        SDValue Ptr =
+            DAG.getObjectPtrOffset(dl, ParamSymbol, TypeSize::getFixed(Offset));
 
-      const MaybeAlign CurrentAlign = ExtendIntegerParam
-                                          ? MaybeAlign(std::nullopt)
-                                          : commonAlignment(ArgAlign, Offset);
+        const MaybeAlign CurrentAlign = ExtendIntegerParam
+                                            ? MaybeAlign(std::nullopt)
+                                            : commonAlignment(ArgAlign, Offset);
 
-      SDValue Val;
-      if (NumElts == 1) {
-        Val = GetStoredValue(J, EltVT, CurrentAlign);
-      } else {
-        SmallVector<SDValue, 8> StoreVals;
-        for (const unsigned K : llvm::seq(NumElts)) {
-          SDValue ValJ = GetStoredValue(J + K, EltVT, CurrentAlign);
-          if (ValJ.getValueType().isVector())
-            DAG.ExtractVectorElements(ValJ, StoreVals);
-          else
-            StoreVals.push_back(ValJ);
-        }
+        SDValue Val =
+            getBuildVectorizedValue(NumElts, dl, DAG, [&](unsigned K) {
+              return GetStoredValue(J + K);
+            });
 
-        EVT VT = EVT::getVectorVT(
-            *DAG.getContext(), StoreVals[0].getValueType(), StoreVals.size());
-        Val = DAG.getBuildVector(VT, dl, StoreVals);
-      }
+        SDValue StoreParam =
+            DAG.getStore(ArgDeclare, dl, Val, Ptr,
+                         MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
+        CallPrereqs.push_back(StoreParam);
 
-      SDValue StoreParam =
-          DAG.getStore(ArgDeclare, dl, Val, Ptr,
-                       MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
-      CallPrereqs.push_back(StoreParam);
-
-      // TODO: We may need to support vector types that can be passed
-      // as scalars in variadic arguments.
-      if (IsVAArg && !IsByVal) {
-        assert(NumElts == 1 &&
-               "Vectorization is expected to be disabled for variadics.");
-        const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
-        VAOffset +=
-            DL.getTypeAllocSize(TheStoreType.getTypeForEVT(*DAG.getContext()));
+        J += NumElts;
       }
-
-      J += NumElts;
     }
-    if (IsVAArg && IsByVal)
-      VAOffset += TypeSize;
   }
 
   // Handle Result
@@ -1676,17 +1733,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     CallPrereqs.push_back(PrototypeDeclare);
   }
 
-  if (ConvertToIndirectCall) {
-    // Copy the function ptr to a ptx register and use the register to call the
-    // function.
-    const MVT DestVT = Callee.getValueType().getSimpleVT();
-    MachineRegisterInfo &MRI = DAG.getMachineFunction().getRegInfo();
-    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-    Register DestReg = MRI.createVirtualRegister(TLI.getRegClassFor(DestVT));
-    auto RegCopy = DAG.getCopyToReg(DAG.getEntryNode(), dl, DestReg, Callee);
-    Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT);
-  }
-
   const unsigned Proto = IsIndirectCall ? UniqueCallSite : 0;
   const unsigned NumArgs =
       std::min<unsigned>(CLI.NumFixedArgs + 1, Args.size());
@@ -1703,10 +1749,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   if (!Ins.empty()) {
     SmallVector<EVT, 16> VTs;
     SmallVector<uint64_t, 16> Offsets;
-    ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0);
+    ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
     assert(VTs.size() == Ins.size() && "Bad value decomposition");
 
     const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
+    const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
 
     // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
     // 32-bits are sign extended or zero extended, depending on whether
@@ -1714,9 +1761,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     const bool ExtendIntegerRetVal =
         RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
 
-    const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
     unsigned I = 0;
-    for (const unsigned NumElts : VectorInfo) {
+    const auto VI = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
+    for (const unsigned NumElts : VI) {
       const MaybeAlign CurrentAlign =
           ExtendIntegerRetVal ? MaybeAlign(std::nullopt)
                               : commonAlignment(RetAlign, Offsets[I]);
@@ -1724,16 +1771,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
       const EVT LoadVT =
           ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
-
-      const unsigned PackingAmt =
-          LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1;
-
-      const EVT VecVT = NumElts == 1 ? LoadVT
-                                     : EVT::getVectorVT(*DAG.getContext(),
-                                                        LoadVT.getScalarType(),
-                                                        NumElts * PackingAmt);
-
-      const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
+      const EVT VecVT = getVectorizedVT(LoadVT, NumElts, Ctx);
       SDValue Ptr =
           DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
 
@@ -1742,17 +1780,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                       MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
 
       LoadChains.push_back(R.getValue(1));
-
-      if (NumElts == 1)
-        ProxyRegOps.push_back(R);
-      else
-        for (const unsigned J : llvm::seq(NumElts)) {
-          SDValue Elt = DAG.getNode(
-              LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
-                                : ISD::EXTRACT_VECTOR_ELT,
-              dl, LoadVT, R, DAG.getVectorIdxConstant(J * PackingAmt, dl));
-          ProxyRegOps.push_back(Elt);
-        }
+      for (const unsigned J : llvm::seq(NumElts))
+        ProxyRegOps.push_back(getExtractVectorizedValue(R, J, LoadVT, dl, DAG));
       I += NumElts;
     }
   }
@@ -3227,11 +3256,10 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
     SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
     const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
     SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
-  MachineFunction &MF = DAG.getMachineFunction();
   const DataLayout &DL = DAG.getDataLayout();
   auto PtrVT = getPointerTy(DAG.getDataLayout());
 
-  const Function *F = &MF.getFunction();
+  const Function &F = DAG.getMachineFunction().getFunction();
 
   SDValue Root = DAG.getRoot();
   SmallVector<SDValue, 16> OutChains;
@@ -3247,7 +3275,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
   // See similar issue in LowerCall.
 
   auto AllIns = ArrayRef(Ins);
-  for (const auto &Arg : F->args()) {
+  for (const auto &Arg : F.args()) {
     const auto ArgIns = AllIns.take_while(
         [&](auto I) { return I.OrigArgIndex == Arg.getArgNo(); });
     AllIns = AllIns.drop_front(ArgIns.size());
@@ -3287,7 +3315,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
       assert(ByvalIn.VT == PtrVT && "ByVal argument must be a pointer");
 
       SDValue P;
-      if (isKernelFunction(*F)) {
+      if (isKernelFunction(F)) {
         P = ArgSymbol;
         P.getNode()->setIROrder(Arg.getArgNo() + 1);
       } else {
@@ -3305,43 +3333,27 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
       assert(VTs.size() == Offsets.size() && "Size mismatch");
 
       const Align ArgAlign = getFunctionArgumentAlignment(
-          F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
+          &F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
 
-      const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
       unsigned I = 0;
-      for (const unsigned NumElts : VectorInfo) {
+      const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
+      for (const unsigned NumElts : VI) {
         // i1 is loaded/stored as i8
         const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
-        // If the element is a packed type (ex. v2f16, v4i8, etc) holding
-        // multiple elements.
-        const unsigned PackingAmt =
-            LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1;
-
-        const EVT VecVT =
-            NumElts == 1
-                ? LoadVT
-                : EVT::getVectorVT(F->getContext(), LoadVT.getScalarType(),
-                                   NumElts * PackingAmt);
+        const EVT VecVT = getVectorizedVT(LoadVT, NumElts, *DAG.getContext());
 
         SDValue VecAddr = DAG.getObjectPtrOffset(
             dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
 
-        const MaybeAlign PartAlign = commonAlignment(ArgAlign, Offsets[I]);
+        const Align PartAlign = commonAlignment(ArgAlign, Offsets[I]);
         SDValue P =
             DAG.getLoad(VecVT, dl, Root, VecAddr,
                         MachinePointerInfo(ADDRESS_SPACE_PARAM), PartAlign,
                         MachineMemOperand::MODereferenceable |
                             MachineMemOperand::MOInvariant);
-        if (P.getNode())
-          P.getNode()->setIROrder(Arg.getArgNo() + 1);
+        P.getNode()->setIROrder(Arg.getArgNo() + 1);
         for (const unsigned J : llvm::seq(NumElts)) {
-          SDValue Elt =
-              NumElts == 1
-                  ? P
-                  : DAG.getNode(LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
-                                                  : ISD::EXTRACT_VECTOR_ELT,
-                                dl, LoadVT, P,
-                                DAG.getVectorIdxConstant(J * PackingAmt, dl));
+          SDValue Elt = getExtractVectorizedValue(P, J, LoadVT, dl, DAG);
 
           Elt = correctParamType(Elt, ArgIns[I + J].VT, ArgIns[I + J].Flags,
                                  DAG, dl);
@@ -3364,9 +3376,8 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
                                  const SmallVectorImpl<ISD::OutputArg> &Outs,
                                  const SmallVectorImpl<SDValue> &OutVals,
                                  const SDLoc &dl, SelectionDAG &DAG) const {
-  const MachineFunction &MF = DAG.getMachineFunction();
-  const Function &F = MF.getFunction();
-  Type *RetTy = MF.getFunction().getReturnType();
+  const Function &F = DAG.getMachineFunction().getFunction();
+  Type *RetTy = F.getReturnType();
 
   if (RetTy->isVoidTy()) {
     assert(OutVals.empty() && Outs.empty() && "Return value expected for void");
@@ -3374,10 +3385,9 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
   }
 
   const DataLayout &DL = DAG.getDataLayout();
-  SmallVector<EVT, 16> VTs;
-  SmallVector<uint64_t, 16> Offsets;
-  ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
-  assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
+
+  const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
+  const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL);
 
   // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
   // 32-bits are sign extended or zero extended, depending on whether
@@ -3385,6 +3395,11 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
   const bool ExtendIntegerRetVal =
       RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
 
+  SmallVector<EVT, 16> VTs;
+  SmallVector<uint64_t, 16> Offsets;
+  ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
+  assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
+
   const auto GetRetVal = [&](unsigned I) -> SDValue {
     SDValue RetVal = OutVals[I];
     assert(promoteScalarIntegerPTX(RetVal.getValueType()) ==
@@ -3397,33 +3412,16 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
     return correctParamType(RetVal, StoreVT, Outs[I].Flags, DAG, dl);
   };
 
-  const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL);
-  const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
   unsigned I = 0;
-  for (const unsigned NumElts : VectorInfo) {
+  const auto VI = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
+  for (const unsigned NumElts : VI) {
     const MaybeAlign CurrentAlign = ExtendIntegerRetVal
                                         ? MaybeAlign(std::nullopt)
                                         : commonAlignment(RetAlign, Offsets[I]);
 
-    SDValue Val;
-    if (NumElts == 1) {
-      Val = GetRetVal(I);
-    } else {
-      SmallVector<SDValue, 4> StoreVals;
-      for (const unsigned J : llvm::seq(NumElts)) {
-        SDValue ValJ = GetRetVal(I + J);
-        if (ValJ.getValueType().isVector())
-          DAG.ExtractVectorElements(ValJ, StoreVals);
-        else
-          StoreVals.push_back(ValJ);
-      }
-
-      EVT VT = EVT::getVectorVT(F.getContext(), StoreVals[0].getValueType(),
-                                StoreVals.size());
-      Val = DAG.getBuildVector(VT, dl, StoreVals);
-    }
+    SDValue Val = getBuildVectorizedValue(
+        NumElts, dl, DAG, [&](unsigned K) { return GetRetVal(I + K); });
 
-    const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
     SDValue Ptr =
         DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
 

diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index d8047d31ff6f0..2ae7520417b1f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1602,8 +1602,6 @@ foreach is_convergent = [0, 1] in {
   }
 
   defvar call_inst = !cast<NVPTXInst>("CALL" # convergent_suffix);
-  def : Pat<(call is_convergent, 1, imm:$rets, imm:$params, globaladdr:$addr, imm:$proto),
-            (call_inst (to_tglobaladdr $addr), imm:$rets, imm:$params, imm:$proto)>;
   def : Pat<(call is_convergent, 1, imm:$rets, imm:$params, i32:$addr, imm:$proto),
             (call_inst $addr, imm:$rets, imm:$params, imm:$proto)>;
   def : Pat<(call is_convergent, 1, imm:$rets, imm:$params, i64:$addr, imm:$proto),
@@ -1612,10 +1610,6 @@ foreach is_convergent = [0, 1] in {
   defvar call_uni_inst = !cast<NVPTXInst>("CALL_UNI" # convergent_suffix);
   def : Pat<(call is_convergent, 0, imm:$rets, imm:$params, globaladdr:$addr, 0),
             (call_uni_inst (to_tglobaladdr $addr), imm:$rets, imm:$params)>;
-  def : Pat<(call is_convergent, 0, imm:$rets, imm:$params, i32:$addr, 0),
-            (call_uni_inst $addr, imm:$rets, imm:$params)>;
-  def : Pat<(call is_convergent, 0, imm:$rets, imm:$params, i64:$addr, 0),
-            (call_uni_inst $addr, imm:$rets, imm:$params)>;
 }
 
 def DECLARE_PARAM_array :

diff --git a/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll b/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll
new file mode 100644
index 0000000000000..9988d5b122cc1
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll
@@ -0,0 +1,38 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mcpu=sm_70 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -mcpu=sm_70 | %ptxas-verify -arch=sm_70 %}
+
+target triple = "nvptx64-nvidia-cuda"
+
+%struct.double2 = type { double, double }
+
+declare %struct.double2 @add(ptr align(16) byval(%struct.double2), ptr align(16) byval(%struct.double2))
+
+define void @call_byval(ptr %out, ptr %in1, ptr %in2) {
+; CHECK-LABEL: call_byval(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b64 %rd<12>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [call_byval_param_0];
+; CHECK-NEXT:    { // callseq 0, 0
+; CHECK-NEXT:    .param .align 16 .b8 param0[16];
+; CHECK-NEXT:    .param .align 16 .b8 param1[16];
+; CHECK-NEXT:    .param .align 8 .b8 retval0[16];
+; CHECK-NEXT:    ld.param.b64 %rd2, [call_byval_param_2];
+; CHECK-NEXT:    ld.v2.b64 {%rd3, %rd4}, [%rd2];
+; CHECK-NEXT:    st.param.v2.b64 [param1], {%rd3, %rd4};
+; CHECK-NEXT:    ld.param.b64 %rd5, [call_byval_param_1];
+; CHECK-NEXT:    ld.v2.b64 {%rd6, %rd7}, [%rd5];
+; CHECK-NEXT:    st.param.v2.b64 [param0], {%rd6, %rd7};
+; CHECK-NEXT:    call.uni (retval0), add, (param0, param1);
+; CHECK-NEXT:    ld.param.b64 %rd8, [retval0+8];
+; CHECK-NEXT:    ld.param.b64 %rd9, [retval0];
+; CHECK-NEXT:    } // callseq 0
+; CHECK-NEXT:    st.b64 [%rd1+8], %rd8;
+; CHECK-NEXT:    st.b64 [%rd1], %rd9;
+; CHECK-NEXT:    ret;
+  %call = call %struct.double2 @add(ptr align(16) byval(%struct.double2) %in1, ptr align(16) byval(%struct.double2) %in2)
+  store %struct.double2 %call, ptr %out, align 16
+  ret void
+}

diff --git a/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll b/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll
index 48209a8c88682..dd3e4ecddcd2e 100644
--- a/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll
+++ b/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll
@@ -12,14 +12,14 @@ define %struct.64 @test_return_type_mismatch(ptr %p) {
 ; CHECK-NEXT:    .reg .b64 %rd<40>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.b64 %rd2, [test_return_type_mismatch_param_0];
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_return_type_mismatch_param_0];
 ; CHECK-NEXT:    { // callseq 0, 0
 ; CHECK-NEXT:    .param .b64 param0;
 ; CHECK-NEXT:    .param .align 1 .b8 retval0[8];
-; CHECK-NEXT:    st.param.b64 [param0], %rd2;
+; CHECK-NEXT:    st.param.b64 [param0], %rd1;
 ; CHECK-NEXT:    prototype_0 : .callprototype (.param .align 1 .b8 _[8]) _ (.param .b64 _);
-; CHECK-NEXT:    mov.b64 %rd1, callee;
-; CHECK-NEXT:    call (retval0), %rd1, (param0), prototype_0;
+; CHECK-NEXT:    mov.b64 %rd2, callee;
+; CHECK-NEXT:    call (retval0), %rd2, (param0), prototype_0;
 ; CHECK-NEXT:    ld.param.b8 %rd3, [retval0+7];
 ; CHECK-NEXT:    ld.param.b8 %rd4, [retval0+6];
 ; CHECK-NEXT:    ld.param.b8 %rd5, [retval0+5];
@@ -90,16 +90,16 @@ define i64 @test_param_count_mismatch(ptr %p) {
 ; CHECK-NEXT:    .reg .b64 %rd<5>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.b64 %rd2, [test_param_count_mismatch_param_0];
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_param_count_mismatch_param_0];
 ; CHECK-NEXT:    { // callseq 2, 0
 ; CHECK-NEXT:    .param .b64 param0;
 ; CHECK-NEXT:    .param .b64 param1;
 ; CHECK-NEXT:    .param .b64 retval0;
-; CHECK-NEXT:    st.param.b64 [param0], %rd2;
+; CHECK-NEXT:    st.param.b64 [param0], %rd1;
 ; CHECK-NEXT:    prototype_2 : .callprototype (.param .b64 _) _ (.param .b64 _, .param .b64 _);
 ; CHECK-NEXT:    st.param.b64 [param1], 7;
-; CHECK-NEXT:    mov.b64 %rd1, callee;
-; CHECK-NEXT:    call (retval0), %rd1, (param0, param1), prototype_2;
+; CHECK-NEXT:    mov.b64 %rd2, callee;
+; CHECK-NEXT:    call (retval0), %rd2, (param0, param1), prototype_2;
 ; CHECK-NEXT:    ld.param.b64 %rd3, [retval0];
 ; CHECK-NEXT:    } // callseq 2
 ; CHECK-NEXT:    st.param.b64 [func_retval0], %rd3;
@@ -114,14 +114,14 @@ define %struct.64 @test_return_type_mismatch_variadic(ptr %p) {
 ; CHECK-NEXT:    .reg .b64 %rd<40>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.b64 %rd2, [test_return_type_mismatch_variadic_param_0];
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_return_type_mismatch_variadic_param_0];
 ; CHECK-NEXT:    { // callseq 3, 0
 ; CHECK-NEXT:    .param .b64 param0;
 ; CHECK-NEXT:    .param .align 1 .b8 retval0[8];
-; CHECK-NEXT:    st.param.b64 [param0], %rd2;
+; CHECK-NEXT:    st.param.b64 [param0], %rd1;
 ; CHECK-NEXT:    prototype_3 : .callprototype (.param .align 1 .b8 _[8]) _ (.param .b64 _);
-; CHECK-NEXT:    mov.b64 %rd1, callee_variadic;
-; CHECK-NEXT:    call (retval0), %rd1, (param0), prototype_3;
+; CHECK-NEXT:    mov.b64 %rd2, callee_variadic;
+; CHECK-NEXT:    call (retval0), %rd2, (param0), prototype_3;
 ; CHECK-NEXT:    ld.param.b8 %rd3, [retval0+7];
 ; CHECK-NEXT:    ld.param.b8 %rd4, [retval0+6];
 ; CHECK-NEXT:    ld.param.b8 %rd5, [retval0+5];

diff --git a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
index 38185c7bf30de..045704bdcd3fc 100644
--- a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
@@ -124,15 +124,15 @@ define ptx_kernel void @grid_const_escape(ptr byval(%struct.s) align 4 %input) {
 ; PTX-NEXT:    .reg .b64 %rd<4>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0:
-; PTX-NEXT:    mov.b64 %rd2, grid_const_escape_param_0;
-; PTX-NEXT:    cvta.param.u64 %rd3, %rd2;
+; PTX-NEXT:    mov.b64 %rd1, grid_const_escape_param_0;
+; PTX-NEXT:    cvta.param.u64 %rd2, %rd1;
 ; PTX-NEXT:    { // callseq 0, 0
 ; PTX-NEXT:    .param .b64 param0;
 ; PTX-NEXT:    .param .b32 retval0;
-; PTX-NEXT:    st.param.b64 [param0], %rd3;
+; PTX-NEXT:    st.param.b64 [param0], %rd2;
 ; PTX-NEXT:    prototype_0 : .callprototype (.param .b32 _) _ (.param .b64 _);
-; PTX-NEXT:    mov.b64 %rd1, escape;
-; PTX-NEXT:    call (retval0), %rd1, (param0), prototype_0;
+; PTX-NEXT:    mov.b64 %rd3, escape;
+; PTX-NEXT:    call (retval0), %rd3, (param0), prototype_0;
 ; PTX-NEXT:    } // callseq 0
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel void @grid_const_escape(
@@ -157,25 +157,25 @@ define ptx_kernel void @multiple_grid_const_escape(ptr byval(%struct.s) align 4
 ; PTX-NEXT:  // %bb.0:
 ; PTX-NEXT:    mov.b64 %SPL, __local_depot4;
 ; PTX-NEXT:    cvta.local.u64 %SP, %SPL;
-; PTX-NEXT:    mov.b64 %rd2, multiple_grid_const_escape_param_0;
+; PTX-NEXT:    mov.b64 %rd1, multiple_grid_const_escape_param_0;
 ; PTX-NEXT:    ld.param.b32 %r1, [multiple_grid_const_escape_param_1];
-; PTX-NEXT:    mov.b64 %rd3, multiple_grid_const_escape_param_2;
-; PTX-NEXT:    cvta.param.u64 %rd4, %rd3;
-; PTX-NEXT:    cvta.param.u64 %rd5, %rd2;
-; PTX-NEXT:    add.u64 %rd6, %SP, 0;
-; PTX-NEXT:    add.u64 %rd7, %SPL, 0;
-; PTX-NEXT:    st.local.b32 [%rd7], %r1;
+; PTX-NEXT:    mov.b64 %rd2, multiple_grid_const_escape_param_2;
+; PTX-NEXT:    cvta.param.u64 %rd3, %rd2;
+; PTX-NEXT:    cvta.param.u64 %rd4, %rd1;
+; PTX-NEXT:    add.u64 %rd5, %SP, 0;
+; PTX-NEXT:    add.u64 %rd6, %SPL, 0;
+; PTX-NEXT:    st.local.b32 [%rd6], %r1;
 ; PTX-NEXT:    { // callseq 1, 0
 ; PTX-NEXT:    .param .b64 param0;
 ; PTX-NEXT:    .param .b64 param1;
 ; PTX-NEXT:    .param .b64 param2;
 ; PTX-NEXT:    .param .b32 retval0;
-; PTX-NEXT:    st.param.b64 [param2], %rd4;
-; PTX-NEXT:    st.param.b64 [param1], %rd6;
-; PTX-NEXT:    st.param.b64 [param0], %rd5;
+; PTX-NEXT:    st.param.b64 [param2], %rd3;
+; PTX-NEXT:    st.param.b64 [param1], %rd5;
+; PTX-NEXT:    st.param.b64 [param0], %rd4;
 ; PTX-NEXT:    prototype_1 : .callprototype (.param .b32 _) _ (.param .b64 _, .param .b64 _, .param .b64 _);
-; PTX-NEXT:    mov.b64 %rd1, escape3;
-; PTX-NEXT:    call (retval0), %rd1, (param0, param1, param2), prototype_1;
+; PTX-NEXT:    mov.b64 %rd7, escape3;
+; PTX-NEXT:    call (retval0), %rd7, (param0, param1, param2), prototype_1;
 ; PTX-NEXT:    } // callseq 1
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel void @multiple_grid_const_escape(
@@ -256,20 +256,20 @@ define ptx_kernel void @grid_const_partial_escape(ptr byval(i32) %input, ptr %ou
 ; PTX-NEXT:    .reg .b64 %rd<6>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0:
-; PTX-NEXT:    mov.b64 %rd2, grid_const_partial_escape_param_0;
-; PTX-NEXT:    ld.param.b64 %rd3, [grid_const_partial_escape_param_1];
-; PTX-NEXT:    cvta.to.global.u64 %rd4, %rd3;
-; PTX-NEXT:    cvta.param.u64 %rd5, %rd2;
+; PTX-NEXT:    mov.b64 %rd1, grid_const_partial_escape_param_0;
+; PTX-NEXT:    ld.param.b64 %rd2, [grid_const_partial_escape_param_1];
+; PTX-NEXT:    cvta.to.global.u64 %rd3, %rd2;
+; PTX-NEXT:    cvta.param.u64 %rd4, %rd1;
 ; PTX-NEXT:    ld.param.b32 %r1, [grid_const_partial_escape_param_0];
 ; PTX-NEXT:    add.s32 %r2, %r1, %r1;
-; PTX-NEXT:    st.global.b32 [%rd4], %r2;
+; PTX-NEXT:    st.global.b32 [%rd3], %r2;
 ; PTX-NEXT:    { // callseq 2, 0
 ; PTX-NEXT:    .param .b64 param0;
 ; PTX-NEXT:    .param .b32 retval0;
-; PTX-NEXT:    st.param.b64 [param0], %rd5;
+; PTX-NEXT:    st.param.b64 [param0], %rd4;
 ; PTX-NEXT:    prototype_2 : .callprototype (.param .b32 _) _ (.param .b64 _);
-; PTX-NEXT:    mov.b64 %rd1, escape;
-; PTX-NEXT:    call (retval0), %rd1, (param0), prototype_2;
+; PTX-NEXT:    mov.b64 %rd5, escape;
+; PTX-NEXT:    call (retval0), %rd5, (param0), prototype_2;
 ; PTX-NEXT:    } // callseq 2
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel void @grid_const_partial_escape(
@@ -295,21 +295,21 @@ define ptx_kernel i32 @grid_const_partial_escapemem(ptr byval(%struct.s) %input,
 ; PTX-NEXT:    .reg .b64 %rd<6>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0:
-; PTX-NEXT:    mov.b64 %rd2, grid_const_partial_escapemem_param_0;
-; PTX-NEXT:    ld.param.b64 %rd3, [grid_const_partial_escapemem_param_1];
-; PTX-NEXT:    cvta.to.global.u64 %rd4, %rd3;
-; PTX-NEXT:    cvta.param.u64 %rd5, %rd2;
+; PTX-NEXT:    mov.b64 %rd1, grid_const_partial_escapemem_param_0;
+; PTX-NEXT:    ld.param.b64 %rd2, [grid_const_partial_escapemem_param_1];
+; PTX-NEXT:    cvta.to.global.u64 %rd3, %rd2;
+; PTX-NEXT:    cvta.param.u64 %rd4, %rd1;
 ; PTX-NEXT:    ld.param.b32 %r1, [grid_const_partial_escapemem_param_0];
 ; PTX-NEXT:    ld.param.b32 %r2, [grid_const_partial_escapemem_param_0+4];
-; PTX-NEXT:    st.global.b64 [%rd4], %rd5;
+; PTX-NEXT:    st.global.b64 [%rd3], %rd4;
 ; PTX-NEXT:    add.s32 %r3, %r1, %r2;
 ; PTX-NEXT:    { // callseq 3, 0
 ; PTX-NEXT:    .param .b64 param0;
 ; PTX-NEXT:    .param .b32 retval0;
-; PTX-NEXT:    st.param.b64 [param0], %rd5;
+; PTX-NEXT:    st.param.b64 [param0], %rd4;
 ; PTX-NEXT:    prototype_3 : .callprototype (.param .b32 _) _ (.param .b64 _);
-; PTX-NEXT:    mov.b64 %rd1, escape;
-; PTX-NEXT:    call (retval0), %rd1, (param0), prototype_3;
+; PTX-NEXT:    mov.b64 %rd5, escape;
+; PTX-NEXT:    call (retval0), %rd5, (param0), prototype_3;
 ; PTX-NEXT:    } // callseq 3
 ; PTX-NEXT:    st.param.b32 [func_retval0], %r3;
 ; PTX-NEXT:    ret;

diff --git a/llvm/test/CodeGen/NVPTX/param-vectorize-device.ll b/llvm/test/CodeGen/NVPTX/param-vectorize-device.ll
index a592b82614f43..51f6b00601069 100644
--- a/llvm/test/CodeGen/NVPTX/param-vectorize-device.ll
+++ b/llvm/test/CodeGen/NVPTX/param-vectorize-device.ll
@@ -150,8 +150,8 @@ define dso_local void @caller_St4x3(ptr nocapture noundef readonly byval(%struct
   ; CHECK:       )
   ; CHECK:       .param .align 16 .b8 param0[12];
   ; CHECK:       .param .align 16 .b8 retval0[12];
-  ; CHECK:       st.param.v2.b32 [param0], {{{%r[0-9]+}}, {{%r[0-9]+}}};
-  ; CHECK:       st.param.b32    [param0+8], {{%r[0-9]+}};
+  ; CHECK-DAG:   st.param.v2.b32 [param0], {{{%r[0-9]+}}, {{%r[0-9]+}}};
+  ; CHECK-DAG:   st.param.b32    [param0+8], {{%r[0-9]+}};
   ; CHECK:       call.uni (retval0), callee_St4x3, (param0);
   ; CHECK:       ld.param.v2.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0];
   ; CHECK:       ld.param.b32    {{%r[0-9]+}},  [retval0+8];
@@ -240,8 +240,8 @@ define dso_local void @caller_St4x5(ptr nocapture noundef readonly byval(%struct
   ; CHECK:       )
   ; CHECK:       .param .align 16 .b8 param0[20];
   ; CHECK:       .param .align 16 .b8 retval0[20];
-  ; CHECK:       st.param.v4.b32 [param0],  {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
-  ; CHECK:       st.param.b32    [param0+16], {{%r[0-9]+}};
+  ; CHECK-DAG:   st.param.v4.b32 [param0],  {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
+  ; CHECK-DAG:   st.param.b32    [param0+16], {{%r[0-9]+}};
   ; CHECK:       call.uni (retval0), callee_St4x5, (param0);
   ; CHECK:       ld.param.v4.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0];
   ; CHECK:       ld.param.b32    {{%r[0-9]+}},  [retval0+16];
@@ -296,8 +296,8 @@ define dso_local void @caller_St4x6(ptr nocapture noundef readonly byval(%struct
   ; CHECK:       )
   ; CHECK:       .param .align 16 .b8 param0[24];
   ; CHECK:       .param .align 16 .b8 retval0[24];
-  ; CHECK:       st.param.v4.b32 [param0],  {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
-  ; CHECK:       st.param.v2.b32 [param0+16], {{{%r[0-9]+}}, {{%r[0-9]+}}};
+  ; CHECK-DAG:   st.param.v4.b32 [param0],  {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
+  ; CHECK-DAG:   st.param.v2.b32 [param0+16], {{{%r[0-9]+}}, {{%r[0-9]+}}};
   ; CHECK:       call.uni (retval0), callee_St4x6, (param0);
   ; CHECK:       ld.param.v4.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0];
   ; CHECK:       ld.param.v2.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0+16];
@@ -358,9 +358,9 @@ define dso_local void @caller_St4x7(ptr nocapture noundef readonly byval(%struct
   ; CHECK:       )
   ; CHECK:       .param .align 16 .b8 param0[28];
   ; CHECK:       .param .align 16 .b8 retval0[28];
-  ; CHECK:       st.param.v4.b32 [param0],  {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
-  ; CHECK:       st.param.v2.b32 [param0+16], {{{%r[0-9]+}}, {{%r[0-9]+}}};
-  ; CHECK:       st.param.b32    [param0+24], {{%r[0-9]+}};
+  ; CHECK-DAG:   st.param.v4.b32 [param0],  {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
+  ; CHECK-DAG:   st.param.v2.b32 [param0+16], {{{%r[0-9]+}}, {{%r[0-9]+}}};
+  ; CHECK-DAG:   st.param.b32    [param0+24], {{%r[0-9]+}};
   ; CHECK:       call.uni (retval0), callee_St4x7, (param0);
   ; CHECK:       ld.param.v4.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0];
   ; CHECK:       ld.param.v2.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0+16];
@@ -566,8 +566,8 @@ define dso_local void @caller_St8x3(ptr nocapture noundef readonly byval(%struct
   ; CHECK:       )
   ; CHECK:       .param .align 16 .b8 param0[24];
   ; CHECK:       .param .align 16 .b8 retval0[24];
-  ; CHECK:       st.param.v2.b64 [param0],  {{{%rd[0-9]+}}, {{%rd[0-9]+}}};
-  ; CHECK:       st.param.b64    [param0+16], {{%rd[0-9]+}};
+  ; CHECK-DAG:   st.param.v2.b64 [param0],  {{{%rd[0-9]+}}, {{%rd[0-9]+}}};
+  ; CHECK-DAG:   st.param.b64    [param0+16], {{%rd[0-9]+}};
   ; CHECK:       call.uni (retval0), callee_St8x3, (param0);
   ; CHECK:       ld.param.v2.b64 {{{%rd[0-9]+}}, {{%rd[0-9]+}}}, [retval0];
   ; CHECK:       ld.param.b64    {{%rd[0-9]+}}, [retval0+16];


        

