[llvm] [NVPTX] Vectorize loads when lowering of byval calls, misc. cleanup (PR #151070)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 28 19:04:08 PDT 2025


https://github.com/AlexMaclean created https://github.com/llvm/llvm-project/pull/151070

This change rewrites the LowerCall handling of byval arguments to vectorize the loads in addition to the stores. In addition, various minor NFC updates and cleanups are made to reduce code duplication.

>From b28612948c87c35b76521c2e8967375364935e17 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Mon, 28 Jul 2025 22:46:52 +0000
Subject: [PATCH 1/2] pre-commit tests

---
 .../test/CodeGen/NVPTX/byval-arg-vectorize.ll | 40 +++++++++++++++++++
 1 file changed, 40 insertions(+)
 create mode 100644 llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll

diff --git a/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll b/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll
new file mode 100644
index 0000000000000..4756b16751f39
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll
@@ -0,0 +1,40 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mcpu=sm_70 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -mcpu=sm_70 | %ptxas-verify -arch=sm_70 %}
+
+target triple = "nvptx64-nvidia-cuda"
+
+%struct.double2 = type { double, double }
+
+declare %struct.double2 @add(ptr align(16) byval(%struct.double2), ptr align(16) byval(%struct.double2))
+
+define void @call_byval(ptr %out, ptr %in1, ptr %in2) {
+; CHECK-LABEL: call_byval(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b64 %rd<12>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [call_byval_param_0];
+; CHECK-NEXT:    { // callseq 0, 0
+; CHECK-NEXT:    .param .align 16 .b8 param0[16];
+; CHECK-NEXT:    .param .align 16 .b8 param1[16];
+; CHECK-NEXT:    .param .align 8 .b8 retval0[16];
+; CHECK-NEXT:    ld.param.b64 %rd2, [call_byval_param_2];
+; CHECK-NEXT:    ld.b64 %rd3, [%rd2+8];
+; CHECK-NEXT:    ld.b64 %rd4, [%rd2];
+; CHECK-NEXT:    st.param.v2.b64 [param1], {%rd4, %rd3};
+; CHECK-NEXT:    ld.param.b64 %rd5, [call_byval_param_1];
+; CHECK-NEXT:    ld.b64 %rd6, [%rd5+8];
+; CHECK-NEXT:    ld.b64 %rd7, [%rd5];
+; CHECK-NEXT:    st.param.v2.b64 [param0], {%rd7, %rd6};
+; CHECK-NEXT:    call.uni (retval0), add, (param0, param1);
+; CHECK-NEXT:    ld.param.b64 %rd8, [retval0+8];
+; CHECK-NEXT:    ld.param.b64 %rd9, [retval0];
+; CHECK-NEXT:    } // callseq 0
+; CHECK-NEXT:    st.b64 [%rd1+8], %rd8;
+; CHECK-NEXT:    st.b64 [%rd1], %rd9;
+; CHECK-NEXT:    ret;
+  %call = call %struct.double2 @add(ptr align(16) byval(%struct.double2) %in1, ptr align(16) byval(%struct.double2) %in2)
+  store %struct.double2 %call, ptr %out, align 16
+  ret void
+}

>From 3937042ce5d993e5eae7fe7144dbcf3024990b95 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Tue, 29 Jul 2025 02:02:37 +0000
Subject: [PATCH 2/2] [NVPTX] Vectorize loads when lowering of byval calls,
 misc. cleanup

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   | 364 +++++++++---------
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td       |   6 -
 .../test/CodeGen/NVPTX/byval-arg-vectorize.ll |  10 +-
 .../CodeGen/NVPTX/convert-call-to-indirect.ll |  24 +-
 .../CodeGen/NVPTX/lower-args-gridconstant.ll  |  66 ++--
 .../CodeGen/NVPTX/param-vectorize-device.ll   |  22 +-
 6 files changed, 239 insertions(+), 253 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index f79b8629f01e2..0a4726212bf12 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -382,6 +382,51 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
   }
 }
 
+static EVT getVectorizedVT(EVT VT, unsigned N, LLVMContext &C) {
+  if (N == 1)
+    return VT;
+
+  const unsigned PackingAmt = VT.isVector() ? VT.getVectorNumElements() : 1;
+  return EVT::getVectorVT(C, VT.getScalarType(), N * PackingAmt);
+}
+
+static SDValue getExtractVectorizedValue(SDValue V, unsigned I, EVT VT,
+                                         const SDLoc &dl, SelectionDAG &DAG) {
+  if (V.getValueType() == VT) {
+    assert(I == 0 && "Index must be 0 for scalar value");
+    return V;
+  }
+
+  if (!VT.isVector())
+    return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, V,
+                       DAG.getVectorIdxConstant(I, dl));
+
+  return DAG.getNode(
+      ISD::EXTRACT_SUBVECTOR, dl, VT, V,
+      DAG.getVectorIdxConstant(I * VT.getVectorNumElements(), dl));
+}
+
+template <typename T>
+static inline SDValue getBuildVectorizedValue(T GetElement, unsigned N,
+                                              const SDLoc &dl,
+                                              SelectionDAG &DAG) {
+  if (N == 1)
+    return GetElement(0);
+
+  SmallVector<SDValue, 8> Values;
+  for (const unsigned I : llvm::seq(N)) {
+    SDValue Val = GetElement(I);
+    if (Val.getValueType().isVector())
+      DAG.ExtractVectorElements(Val, Values);
+    else
+      Values.push_back(Val);
+  }
+
+  EVT VT = EVT::getVectorVT(*DAG.getContext(), Values[0].getValueType(),
+                            Values.size());
+  return DAG.getBuildVector(VT, dl, Values);
+}
+
 /// PromoteScalarIntegerPTX
 /// Used to make sure the arguments/returns are suitable for passing
 /// and promote them to a larger size if they're not.
@@ -420,9 +465,10 @@ static EVT promoteScalarIntegerPTX(const EVT VT) {
 // parameter starting at index Idx using a single vectorized op of
 // size AccessSize. If so, it returns the number of param pieces
 // covered by the vector op. Otherwise, it returns 1.
-static unsigned CanMergeParamLoadStoresStartingAt(
+template <typename T>
+static unsigned canMergeParamLoadStoresStartingAt(
     unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs,
-    const SmallVectorImpl<uint64_t> &Offsets, Align ParamAlignment) {
+    const SmallVectorImpl<T> &Offsets, Align ParamAlignment) {
 
   // Can't vectorize if param alignment is not sufficient.
   if (ParamAlignment < AccessSize)
@@ -472,10 +518,11 @@ static unsigned CanMergeParamLoadStoresStartingAt(
 // of the same size as ValueVTs indicating how each piece should be
 // loaded/stored (i.e. as a scalar, or as part of a vector
 // load/store).
+template <typename T>
 static SmallVector<unsigned, 16>
 VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
-                     const SmallVectorImpl<uint64_t> &Offsets,
-                     Align ParamAlignment, bool IsVAArg = false) {
+                     const SmallVectorImpl<T> &Offsets, Align ParamAlignment,
+                     bool IsVAArg = false) {
   // Set vector size to match ValueVTs and mark all elements as
   // scalars by default.
 
@@ -486,7 +533,7 @@ VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
 
   const auto GetNumElts = [&](unsigned I) -> unsigned {
     for (const unsigned AccessSize : {16, 8, 4, 2}) {
-      const unsigned NumElts = CanMergeParamLoadStoresStartingAt(
+      const unsigned NumElts = canMergeParamLoadStoresStartingAt(
           I, AccessSize, ValueVTs, Offsets, ParamAlignment);
       assert((NumElts == 1 || NumElts == 2 || NumElts == 4) &&
              "Unexpected vectorization size");
@@ -1476,15 +1523,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     const SDValue ParamSymbol =
         getCallParamSymbol(DAG, IsVAArg ? FirstVAArg : ArgI, MVT::i32);
 
-    SmallVector<EVT, 16> VTs;
-    SmallVector<uint64_t, 16> Offsets;
-
     assert((!IsByVal || Arg.IndirectType) &&
            "byval arg must have indirect type");
     Type *ETy = (IsByVal ? Arg.IndirectType : Arg.Ty);
-    ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, IsByVal ? 0 : VAOffset);
-    assert(VTs.size() == Offsets.size() && "Size mismatch");
-    assert((IsByVal || VTs.size() == ArgOuts.size()) && "Size mismatch");
 
     const Align ArgAlign = [&]() {
       if (IsByVal) {
@@ -1492,17 +1533,14 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
         // so we don't need to worry whether it's naturally aligned or not.
         // See TargetLowering::LowerCallTo().
         const Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
-        const Align ByValAlign = getFunctionByValParamAlign(
+        return getFunctionByValParamAlign(
             CB->getCalledFunction(), ETy, InitialAlign, DL);
-        if (IsVAArg)
-          VAOffset = alignTo(VAOffset, ByValAlign);
-        return ByValAlign;
       }
       return getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL);
     }();
 
-    const unsigned TypeSize = DL.getTypeAllocSize(ETy);
-    assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) &&
+    const unsigned TySize = DL.getTypeAllocSize(ETy);
+    assert((!IsByVal || TySize == ArgOuts[0].Flags.getByValSize()) &&
            "type size mismatch");
 
     const SDValue ArgDeclare = [&]() {
@@ -1510,105 +1548,120 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
         return VADeclareParam;
 
       if (IsByVal || shouldPassAsArray(Arg.Ty))
-        return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TypeSize);
+        return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TySize);
 
       assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
       assert((ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint()) &&
              "Only int and float types are supported as non-array arguments");
 
-      return MakeDeclareScalarParam(ParamSymbol, TypeSize);
+      return MakeDeclareScalarParam(ParamSymbol, TySize);
     }();
 
-    // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
-    // than 32-bits are sign extended or zero extended, depending on
-    // whether they are signed or unsigned types. This case applies
-    // only to scalar parameters and not to aggregate values.
-    const bool ExtendIntegerParam =
-        Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
-
-    const auto GetStoredValue = [&](const unsigned I, EVT EltVT,
-                                    const MaybeAlign PartAlign) {
-      if (IsByVal) {
-        SDValue Ptr = ArgOutVals[0];
-        auto MPI = refinePtrAS(Ptr, DAG, DL, *this);
-        SDValue SrcAddr =
-            DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(Offsets[I]));
+    if (IsByVal) {
+      assert(ArgOutVals.size() == 1 && "We must pass only one value as byval");
+      SDValue SrcPtr = ArgOutVals[0];
+      auto PointerInfo = refinePtrAS(SrcPtr, DAG, DL, *this);
 
-        return DAG.getLoad(EltVT, dl, CallChain, SrcAddr, MPI, PartAlign);
-      }
-      SDValue StVal = ArgOutVals[I];
-      assert(promoteScalarIntegerPTX(StVal.getValueType()) ==
-                 StVal.getValueType() &&
-             "OutVal type should always be legal");
+      if (IsVAArg)
+        VAOffset = alignTo(VAOffset, ArgAlign);
 
-      const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
-      const EVT StoreVT =
-          ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
+      SmallVector<EVT, 4> ValueVTs, MemVTs;
+      SmallVector<TypeSize, 4> Offsets;
+      ComputeValueVTs(*this, DL, ETy, ValueVTs, &MemVTs, &Offsets);
 
-      return correctParamType(StVal, StoreVT, ArgOuts[I].Flags, DAG, dl);
-    };
+      unsigned J = 0;
+      const auto VI = VectorizePTXValueVTs(MemVTs, Offsets, ArgAlign, IsVAArg);
+      for (const unsigned NumElts : VI) {
+        const Align Alignment = commonAlignment(ArgAlign, Offsets[J]);
+        const EVT LoadVT =
+            getVectorizedVT(MemVTs[J], NumElts, *DAG.getContext());
 
-    const auto VectorInfo =
-        VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
+        SDValue SrcAddr = DAG.getObjectPtrOffset(dl, SrcPtr, Offsets[J]);
+        SDValue Load =
+            DAG.getLoad(LoadVT, dl, CallChain, SrcAddr, PointerInfo, Alignment);
 
-    unsigned J = 0;
-    for (const unsigned NumElts : VectorInfo) {
-      const int CurOffset = Offsets[J];
-      const EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
+        SDValue ByvalAddr = DAG.getObjectPtrOffset(
+            dl, ParamSymbol, Offsets[J].getWithIncrement(VAOffset));
 
-      if (IsVAArg && !IsByVal)
-        // Align each part of the variadic argument to their type.
-        VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
+        SDValue StoreParam =
+            DAG.getStore(ArgDeclare, dl, Load, ByvalAddr,
+                         MachinePointerInfo(ADDRESS_SPACE_PARAM), Alignment);
+        CallPrereqs.push_back(StoreParam);
 
-      assert((IsVAArg || VAOffset == 0) &&
-             "VAOffset must be 0 for non-VA args");
+        J += NumElts;
+      }
+      if (IsVAArg)
+        VAOffset += TySize;
+    } else {
+      SmallVector<EVT, 16> VTs;
+      SmallVector<uint64_t, 16> Offsets;
+      ComputePTXValueVTs(*this, DL, Arg.Ty, VTs, &Offsets, VAOffset);
+      assert(VTs.size() == Offsets.size() && "Size mismatch");
+      assert(VTs.size() == ArgOuts.size() && "Size mismatch");
+
+      // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
+      // than 32-bits are sign extended or zero extended, depending on
+      // whether they are signed or unsigned types. This case applies
+      // only to scalar parameters and not to aggregate values.
+      const bool ExtendIntegerParam =
+          Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
+
+      const auto GetStoredValue = [&](const unsigned I) {
+        SDValue StVal = ArgOutVals[I];
+        assert(promoteScalarIntegerPTX(StVal.getValueType()) ==
+                   StVal.getValueType() &&
+               "OutVal type should always be legal");
+
+        const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
+        const EVT StoreVT =
+            ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
+
+        return correctParamType(StVal, StoreVT, ArgOuts[I].Flags, DAG, dl);
+      };
+
+      unsigned J = 0;
+      const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
+      for (const unsigned NumElts : VI) {
+        const EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
+
+        unsigned Offset;
+        if (IsVAArg) {
+          // TODO: We may need to support vector types that can be passed
+          // as scalars in variadic arguments.
+          assert(NumElts == 1 &&
+                 "Vectorization should be disabled for vaargs.");
+
+          // Align each part of the variadic argument to their type.
+          VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
+          Offset = VAOffset;
+
+          const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
+          VAOffset += DL.getTypeAllocSize(
+              TheStoreType.getTypeForEVT(*DAG.getContext()));
+        } else {
+          assert(VAOffset == 0 && "VAOffset must be 0 for non-VA args");
+          Offset = Offsets[J];
+        }
 
-      const unsigned Offset =
-          (VAOffset + ((IsVAArg && !IsByVal) ? 0 : CurOffset));
-      SDValue Ptr =
-          DAG.getObjectPtrOffset(dl, ParamSymbol, TypeSize::getFixed(Offset));
+        SDValue Ptr =
+            DAG.getObjectPtrOffset(dl, ParamSymbol, TypeSize::getFixed(Offset));
 
-      const MaybeAlign CurrentAlign = ExtendIntegerParam
-                                          ? MaybeAlign(std::nullopt)
-                                          : commonAlignment(ArgAlign, Offset);
+        const MaybeAlign CurrentAlign = ExtendIntegerParam
+                                            ? MaybeAlign(std::nullopt)
+                                            : commonAlignment(ArgAlign, Offset);
 
-      SDValue Val;
-      if (NumElts == 1) {
-        Val = GetStoredValue(J, EltVT, CurrentAlign);
-      } else {
-        SmallVector<SDValue, 8> StoreVals;
-        for (const unsigned K : llvm::seq(NumElts)) {
-          SDValue ValJ = GetStoredValue(J + K, EltVT, CurrentAlign);
-          if (ValJ.getValueType().isVector())
-            DAG.ExtractVectorElements(ValJ, StoreVals);
-          else
-            StoreVals.push_back(ValJ);
-        }
+        SDValue Val = getBuildVectorizedValue(
+            [&](unsigned K) { return GetStoredValue(J + K); }, NumElts, dl,
+            DAG);
 
-        EVT VT = EVT::getVectorVT(
-            *DAG.getContext(), StoreVals[0].getValueType(), StoreVals.size());
-        Val = DAG.getBuildVector(VT, dl, StoreVals);
-      }
+        SDValue StoreParam =
+            DAG.getStore(ArgDeclare, dl, Val, Ptr,
+                         MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
+        CallPrereqs.push_back(StoreParam);
 
-      SDValue StoreParam =
-          DAG.getStore(ArgDeclare, dl, Val, Ptr,
-                       MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
-      CallPrereqs.push_back(StoreParam);
-
-      // TODO: We may need to support vector types that can be passed
-      // as scalars in variadic arguments.
-      if (IsVAArg && !IsByVal) {
-        assert(NumElts == 1 &&
-               "Vectorization is expected to be disabled for variadics.");
-        const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
-        VAOffset +=
-            DL.getTypeAllocSize(TheStoreType.getTypeForEVT(*DAG.getContext()));
+        J += NumElts;
       }
-
-      J += NumElts;
     }
-    if (IsVAArg && IsByVal)
-      VAOffset += TypeSize;
   }
 
   // Handle Result
@@ -1676,17 +1729,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     CallPrereqs.push_back(PrototypeDeclare);
   }
 
-  if (ConvertToIndirectCall) {
-    // Copy the function ptr to a ptx register and use the register to call the
-    // function.
-    const MVT DestVT = Callee.getValueType().getSimpleVT();
-    MachineRegisterInfo &MRI = DAG.getMachineFunction().getRegInfo();
-    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-    Register DestReg = MRI.createVirtualRegister(TLI.getRegClassFor(DestVT));
-    auto RegCopy = DAG.getCopyToReg(DAG.getEntryNode(), dl, DestReg, Callee);
-    Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT);
-  }
-
   const unsigned Proto = IsIndirectCall ? UniqueCallSite : 0;
   const unsigned NumArgs =
       std::min<unsigned>(CLI.NumFixedArgs + 1, Args.size());
@@ -1703,10 +1745,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   if (!Ins.empty()) {
     SmallVector<EVT, 16> VTs;
     SmallVector<uint64_t, 16> Offsets;
-    ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0);
+    ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
     assert(VTs.size() == Ins.size() && "Bad value decomposition");
 
     const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
+    const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
 
     // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
     // 32-bits are sign extended or zero extended, depending on whether
@@ -1714,9 +1757,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     const bool ExtendIntegerRetVal =
         RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
 
-    const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
     unsigned I = 0;
-    for (const unsigned NumElts : VectorInfo) {
+    const auto VI = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
+    for (const unsigned NumElts : VI) {
       const MaybeAlign CurrentAlign =
           ExtendIntegerRetVal ? MaybeAlign(std::nullopt)
                               : commonAlignment(RetAlign, Offsets[I]);
@@ -1724,16 +1767,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
       const EVT LoadVT =
           ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
-
-      const unsigned PackingAmt =
-          LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1;
-
-      const EVT VecVT = NumElts == 1 ? LoadVT
-                                     : EVT::getVectorVT(*DAG.getContext(),
-                                                        LoadVT.getScalarType(),
-                                                        NumElts * PackingAmt);
-
-      const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
+      const EVT VecVT = getVectorizedVT(LoadVT, NumElts, *DAG.getContext());
       SDValue Ptr =
           DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
 
@@ -1742,17 +1776,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                       MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
 
       LoadChains.push_back(R.getValue(1));
-
-      if (NumElts == 1)
-        ProxyRegOps.push_back(R);
-      else
-        for (const unsigned J : llvm::seq(NumElts)) {
-          SDValue Elt = DAG.getNode(
-              LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
-                                : ISD::EXTRACT_VECTOR_ELT,
-              dl, LoadVT, R, DAG.getVectorIdxConstant(J * PackingAmt, dl));
-          ProxyRegOps.push_back(Elt);
-        }
+      for (const unsigned J : llvm::seq(NumElts))
+        ProxyRegOps.push_back(getExtractVectorizedValue(R, J, LoadVT, dl, DAG));
       I += NumElts;
     }
   }
@@ -3227,11 +3252,10 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
     SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
     const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
     SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
-  MachineFunction &MF = DAG.getMachineFunction();
   const DataLayout &DL = DAG.getDataLayout();
   auto PtrVT = getPointerTy(DAG.getDataLayout());
 
-  const Function *F = &MF.getFunction();
+  const Function &F = DAG.getMachineFunction().getFunction();
 
   SDValue Root = DAG.getRoot();
   SmallVector<SDValue, 16> OutChains;
@@ -3247,7 +3271,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
   // See similar issue in LowerCall.
 
   auto AllIns = ArrayRef(Ins);
-  for (const auto &Arg : F->args()) {
+  for (const auto &Arg : F.args()) {
     const auto ArgIns = AllIns.take_while(
         [&](auto I) { return I.OrigArgIndex == Arg.getArgNo(); });
     AllIns = AllIns.drop_front(ArgIns.size());
@@ -3287,7 +3311,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
       assert(ByvalIn.VT == PtrVT && "ByVal argument must be a pointer");
 
       SDValue P;
-      if (isKernelFunction(*F)) {
+      if (isKernelFunction(F)) {
         P = ArgSymbol;
         P.getNode()->setIROrder(Arg.getArgNo() + 1);
       } else {
@@ -3305,43 +3329,27 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
       assert(VTs.size() == Offsets.size() && "Size mismatch");
 
       const Align ArgAlign = getFunctionArgumentAlignment(
-          F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
+          &F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
 
-      const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
       unsigned I = 0;
-      for (const unsigned NumElts : VectorInfo) {
+      const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
+      for (const unsigned NumElts : VI) {
         // i1 is loaded/stored as i8
         const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
-        // If the element is a packed type (ex. v2f16, v4i8, etc) holding
-        // multiple elements.
-        const unsigned PackingAmt =
-            LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1;
-
-        const EVT VecVT =
-            NumElts == 1
-                ? LoadVT
-                : EVT::getVectorVT(F->getContext(), LoadVT.getScalarType(),
-                                   NumElts * PackingAmt);
+        const EVT VecVT = getVectorizedVT(LoadVT, NumElts, *DAG.getContext());
 
         SDValue VecAddr = DAG.getObjectPtrOffset(
             dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
 
-        const MaybeAlign PartAlign = commonAlignment(ArgAlign, Offsets[I]);
+        const Align PartAlign = commonAlignment(ArgAlign, Offsets[I]);
         SDValue P =
             DAG.getLoad(VecVT, dl, Root, VecAddr,
                         MachinePointerInfo(ADDRESS_SPACE_PARAM), PartAlign,
                         MachineMemOperand::MODereferenceable |
                             MachineMemOperand::MOInvariant);
-        if (P.getNode())
-          P.getNode()->setIROrder(Arg.getArgNo() + 1);
+        P.getNode()->setIROrder(Arg.getArgNo() + 1);
         for (const unsigned J : llvm::seq(NumElts)) {
-          SDValue Elt =
-              NumElts == 1
-                  ? P
-                  : DAG.getNode(LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
-                                                  : ISD::EXTRACT_VECTOR_ELT,
-                                dl, LoadVT, P,
-                                DAG.getVectorIdxConstant(J * PackingAmt, dl));
+          SDValue Elt = getExtractVectorizedValue(P, J, LoadVT, dl, DAG);
 
           Elt = correctParamType(Elt, ArgIns[I + J].VT, ArgIns[I + J].Flags,
                                  DAG, dl);
@@ -3364,9 +3372,8 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
                                  const SmallVectorImpl<ISD::OutputArg> &Outs,
                                  const SmallVectorImpl<SDValue> &OutVals,
                                  const SDLoc &dl, SelectionDAG &DAG) const {
-  const MachineFunction &MF = DAG.getMachineFunction();
-  const Function &F = MF.getFunction();
-  Type *RetTy = MF.getFunction().getReturnType();
+  const Function &F = DAG.getMachineFunction().getFunction();
+  Type *RetTy = F.getReturnType();
 
   if (RetTy->isVoidTy()) {
     assert(OutVals.empty() && Outs.empty() && "Return value expected for void");
@@ -3374,10 +3381,9 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
   }
 
   const DataLayout &DL = DAG.getDataLayout();
-  SmallVector<EVT, 16> VTs;
-  SmallVector<uint64_t, 16> Offsets;
-  ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
-  assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
+
+  const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
+  const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL);
 
   // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
   // 32-bits are sign extended or zero extended, depending on whether
@@ -3385,6 +3391,11 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
   const bool ExtendIntegerRetVal =
       RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
 
+  SmallVector<EVT, 16> VTs;
+  SmallVector<uint64_t, 16> Offsets;
+  ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
+  assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
+
   const auto GetRetVal = [&](unsigned I) -> SDValue {
     SDValue RetVal = OutVals[I];
     assert(promoteScalarIntegerPTX(RetVal.getValueType()) ==
@@ -3397,33 +3408,16 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
     return correctParamType(RetVal, StoreVT, Outs[I].Flags, DAG, dl);
   };
 
-  const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL);
-  const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
   unsigned I = 0;
-  for (const unsigned NumElts : VectorInfo) {
+  const auto VI = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
+  for (const unsigned NumElts : VI) {
     const MaybeAlign CurrentAlign = ExtendIntegerRetVal
                                         ? MaybeAlign(std::nullopt)
                                         : commonAlignment(RetAlign, Offsets[I]);
 
-    SDValue Val;
-    if (NumElts == 1) {
-      Val = GetRetVal(I);
-    } else {
-      SmallVector<SDValue, 4> StoreVals;
-      for (const unsigned J : llvm::seq(NumElts)) {
-        SDValue ValJ = GetRetVal(I + J);
-        if (ValJ.getValueType().isVector())
-          DAG.ExtractVectorElements(ValJ, StoreVals);
-        else
-          StoreVals.push_back(ValJ);
-      }
-
-      EVT VT = EVT::getVectorVT(F.getContext(), StoreVals[0].getValueType(),
-                                StoreVals.size());
-      Val = DAG.getBuildVector(VT, dl, StoreVals);
-    }
+    SDValue Val = getBuildVectorizedValue(
+        [&](unsigned K) { return GetRetVal(I + K); }, NumElts, dl, DAG);
 
-    const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
     SDValue Ptr =
         DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 86d6f7c3fc3a3..45d3a6044826c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1802,8 +1802,6 @@ foreach is_convergent = [0, 1] in {
   }
 
   defvar call_inst = !cast<NVPTXInst>("CALL" # convergent_suffix);
-  def : Pat<(call is_convergent, 1, imm:$rets, imm:$params, globaladdr:$addr, imm:$proto),
-            (call_inst (to_tglobaladdr $addr), imm:$rets, imm:$params, imm:$proto)>;
   def : Pat<(call is_convergent, 1, imm:$rets, imm:$params, i32:$addr, imm:$proto),
             (call_inst $addr, imm:$rets, imm:$params, imm:$proto)>;
   def : Pat<(call is_convergent, 1, imm:$rets, imm:$params, i64:$addr, imm:$proto),
@@ -1812,10 +1810,6 @@ foreach is_convergent = [0, 1] in {
   defvar call_uni_inst = !cast<NVPTXInst>("CALL_UNI" # convergent_suffix);
   def : Pat<(call is_convergent, 0, imm:$rets, imm:$params, globaladdr:$addr, 0),
             (call_uni_inst (to_tglobaladdr $addr), imm:$rets, imm:$params)>;
-  def : Pat<(call is_convergent, 0, imm:$rets, imm:$params, i32:$addr, 0),
-            (call_uni_inst $addr, imm:$rets, imm:$params)>;
-  def : Pat<(call is_convergent, 0, imm:$rets, imm:$params, i64:$addr, 0),
-            (call_uni_inst $addr, imm:$rets, imm:$params)>;
 }
 
 def DECLARE_PARAM_array :
diff --git a/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll b/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll
index 4756b16751f39..9988d5b122cc1 100644
--- a/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll
+++ b/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll
@@ -20,13 +20,11 @@ define void @call_byval(ptr %out, ptr %in1, ptr %in2) {
 ; CHECK-NEXT:    .param .align 16 .b8 param1[16];
 ; CHECK-NEXT:    .param .align 8 .b8 retval0[16];
 ; CHECK-NEXT:    ld.param.b64 %rd2, [call_byval_param_2];
-; CHECK-NEXT:    ld.b64 %rd3, [%rd2+8];
-; CHECK-NEXT:    ld.b64 %rd4, [%rd2];
-; CHECK-NEXT:    st.param.v2.b64 [param1], {%rd4, %rd3};
+; CHECK-NEXT:    ld.v2.b64 {%rd3, %rd4}, [%rd2];
+; CHECK-NEXT:    st.param.v2.b64 [param1], {%rd3, %rd4};
 ; CHECK-NEXT:    ld.param.b64 %rd5, [call_byval_param_1];
-; CHECK-NEXT:    ld.b64 %rd6, [%rd5+8];
-; CHECK-NEXT:    ld.b64 %rd7, [%rd5];
-; CHECK-NEXT:    st.param.v2.b64 [param0], {%rd7, %rd6};
+; CHECK-NEXT:    ld.v2.b64 {%rd6, %rd7}, [%rd5];
+; CHECK-NEXT:    st.param.v2.b64 [param0], {%rd6, %rd7};
 ; CHECK-NEXT:    call.uni (retval0), add, (param0, param1);
 ; CHECK-NEXT:    ld.param.b64 %rd8, [retval0+8];
 ; CHECK-NEXT:    ld.param.b64 %rd9, [retval0];
diff --git a/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll b/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll
index 48209a8c88682..dd3e4ecddcd2e 100644
--- a/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll
+++ b/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll
@@ -12,14 +12,14 @@ define %struct.64 @test_return_type_mismatch(ptr %p) {
 ; CHECK-NEXT:    .reg .b64 %rd<40>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.b64 %rd2, [test_return_type_mismatch_param_0];
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_return_type_mismatch_param_0];
 ; CHECK-NEXT:    { // callseq 0, 0
 ; CHECK-NEXT:    .param .b64 param0;
 ; CHECK-NEXT:    .param .align 1 .b8 retval0[8];
-; CHECK-NEXT:    st.param.b64 [param0], %rd2;
+; CHECK-NEXT:    st.param.b64 [param0], %rd1;
 ; CHECK-NEXT:    prototype_0 : .callprototype (.param .align 1 .b8 _[8]) _ (.param .b64 _);
-; CHECK-NEXT:    mov.b64 %rd1, callee;
-; CHECK-NEXT:    call (retval0), %rd1, (param0), prototype_0;
+; CHECK-NEXT:    mov.b64 %rd2, callee;
+; CHECK-NEXT:    call (retval0), %rd2, (param0), prototype_0;
 ; CHECK-NEXT:    ld.param.b8 %rd3, [retval0+7];
 ; CHECK-NEXT:    ld.param.b8 %rd4, [retval0+6];
 ; CHECK-NEXT:    ld.param.b8 %rd5, [retval0+5];
@@ -90,16 +90,16 @@ define i64 @test_param_count_mismatch(ptr %p) {
 ; CHECK-NEXT:    .reg .b64 %rd<5>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.b64 %rd2, [test_param_count_mismatch_param_0];
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_param_count_mismatch_param_0];
 ; CHECK-NEXT:    { // callseq 2, 0
 ; CHECK-NEXT:    .param .b64 param0;
 ; CHECK-NEXT:    .param .b64 param1;
 ; CHECK-NEXT:    .param .b64 retval0;
-; CHECK-NEXT:    st.param.b64 [param0], %rd2;
+; CHECK-NEXT:    st.param.b64 [param0], %rd1;
 ; CHECK-NEXT:    prototype_2 : .callprototype (.param .b64 _) _ (.param .b64 _, .param .b64 _);
 ; CHECK-NEXT:    st.param.b64 [param1], 7;
-; CHECK-NEXT:    mov.b64 %rd1, callee;
-; CHECK-NEXT:    call (retval0), %rd1, (param0, param1), prototype_2;
+; CHECK-NEXT:    mov.b64 %rd2, callee;
+; CHECK-NEXT:    call (retval0), %rd2, (param0, param1), prototype_2;
 ; CHECK-NEXT:    ld.param.b64 %rd3, [retval0];
 ; CHECK-NEXT:    } // callseq 2
 ; CHECK-NEXT:    st.param.b64 [func_retval0], %rd3;
@@ -114,14 +114,14 @@ define %struct.64 @test_return_type_mismatch_variadic(ptr %p) {
 ; CHECK-NEXT:    .reg .b64 %rd<40>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.b64 %rd2, [test_return_type_mismatch_variadic_param_0];
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_return_type_mismatch_variadic_param_0];
 ; CHECK-NEXT:    { // callseq 3, 0
 ; CHECK-NEXT:    .param .b64 param0;
 ; CHECK-NEXT:    .param .align 1 .b8 retval0[8];
-; CHECK-NEXT:    st.param.b64 [param0], %rd2;
+; CHECK-NEXT:    st.param.b64 [param0], %rd1;
 ; CHECK-NEXT:    prototype_3 : .callprototype (.param .align 1 .b8 _[8]) _ (.param .b64 _);
-; CHECK-NEXT:    mov.b64 %rd1, callee_variadic;
-; CHECK-NEXT:    call (retval0), %rd1, (param0), prototype_3;
+; CHECK-NEXT:    mov.b64 %rd2, callee_variadic;
+; CHECK-NEXT:    call (retval0), %rd2, (param0), prototype_3;
 ; CHECK-NEXT:    ld.param.b8 %rd3, [retval0+7];
 ; CHECK-NEXT:    ld.param.b8 %rd4, [retval0+6];
 ; CHECK-NEXT:    ld.param.b8 %rd5, [retval0+5];
diff --git a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
index 38185c7bf30de..045704bdcd3fc 100644
--- a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
@@ -124,15 +124,15 @@ define ptx_kernel void @grid_const_escape(ptr byval(%struct.s) align 4 %input) {
 ; PTX-NEXT:    .reg .b64 %rd<4>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0:
-; PTX-NEXT:    mov.b64 %rd2, grid_const_escape_param_0;
-; PTX-NEXT:    cvta.param.u64 %rd3, %rd2;
+; PTX-NEXT:    mov.b64 %rd1, grid_const_escape_param_0;
+; PTX-NEXT:    cvta.param.u64 %rd2, %rd1;
 ; PTX-NEXT:    { // callseq 0, 0
 ; PTX-NEXT:    .param .b64 param0;
 ; PTX-NEXT:    .param .b32 retval0;
-; PTX-NEXT:    st.param.b64 [param0], %rd3;
+; PTX-NEXT:    st.param.b64 [param0], %rd2;
 ; PTX-NEXT:    prototype_0 : .callprototype (.param .b32 _) _ (.param .b64 _);
-; PTX-NEXT:    mov.b64 %rd1, escape;
-; PTX-NEXT:    call (retval0), %rd1, (param0), prototype_0;
+; PTX-NEXT:    mov.b64 %rd3, escape;
+; PTX-NEXT:    call (retval0), %rd3, (param0), prototype_0;
 ; PTX-NEXT:    } // callseq 0
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel void @grid_const_escape(
@@ -157,25 +157,25 @@ define ptx_kernel void @multiple_grid_const_escape(ptr byval(%struct.s) align 4
 ; PTX-NEXT:  // %bb.0:
 ; PTX-NEXT:    mov.b64 %SPL, __local_depot4;
 ; PTX-NEXT:    cvta.local.u64 %SP, %SPL;
-; PTX-NEXT:    mov.b64 %rd2, multiple_grid_const_escape_param_0;
+; PTX-NEXT:    mov.b64 %rd1, multiple_grid_const_escape_param_0;
 ; PTX-NEXT:    ld.param.b32 %r1, [multiple_grid_const_escape_param_1];
-; PTX-NEXT:    mov.b64 %rd3, multiple_grid_const_escape_param_2;
-; PTX-NEXT:    cvta.param.u64 %rd4, %rd3;
-; PTX-NEXT:    cvta.param.u64 %rd5, %rd2;
-; PTX-NEXT:    add.u64 %rd6, %SP, 0;
-; PTX-NEXT:    add.u64 %rd7, %SPL, 0;
-; PTX-NEXT:    st.local.b32 [%rd7], %r1;
+; PTX-NEXT:    mov.b64 %rd2, multiple_grid_const_escape_param_2;
+; PTX-NEXT:    cvta.param.u64 %rd3, %rd2;
+; PTX-NEXT:    cvta.param.u64 %rd4, %rd1;
+; PTX-NEXT:    add.u64 %rd5, %SP, 0;
+; PTX-NEXT:    add.u64 %rd6, %SPL, 0;
+; PTX-NEXT:    st.local.b32 [%rd6], %r1;
 ; PTX-NEXT:    { // callseq 1, 0
 ; PTX-NEXT:    .param .b64 param0;
 ; PTX-NEXT:    .param .b64 param1;
 ; PTX-NEXT:    .param .b64 param2;
 ; PTX-NEXT:    .param .b32 retval0;
-; PTX-NEXT:    st.param.b64 [param2], %rd4;
-; PTX-NEXT:    st.param.b64 [param1], %rd6;
-; PTX-NEXT:    st.param.b64 [param0], %rd5;
+; PTX-NEXT:    st.param.b64 [param2], %rd3;
+; PTX-NEXT:    st.param.b64 [param1], %rd5;
+; PTX-NEXT:    st.param.b64 [param0], %rd4;
 ; PTX-NEXT:    prototype_1 : .callprototype (.param .b32 _) _ (.param .b64 _, .param .b64 _, .param .b64 _);
-; PTX-NEXT:    mov.b64 %rd1, escape3;
-; PTX-NEXT:    call (retval0), %rd1, (param0, param1, param2), prototype_1;
+; PTX-NEXT:    mov.b64 %rd7, escape3;
+; PTX-NEXT:    call (retval0), %rd7, (param0, param1, param2), prototype_1;
 ; PTX-NEXT:    } // callseq 1
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel void @multiple_grid_const_escape(
@@ -256,20 +256,20 @@ define ptx_kernel void @grid_const_partial_escape(ptr byval(i32) %input, ptr %ou
 ; PTX-NEXT:    .reg .b64 %rd<6>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0:
-; PTX-NEXT:    mov.b64 %rd2, grid_const_partial_escape_param_0;
-; PTX-NEXT:    ld.param.b64 %rd3, [grid_const_partial_escape_param_1];
-; PTX-NEXT:    cvta.to.global.u64 %rd4, %rd3;
-; PTX-NEXT:    cvta.param.u64 %rd5, %rd2;
+; PTX-NEXT:    mov.b64 %rd1, grid_const_partial_escape_param_0;
+; PTX-NEXT:    ld.param.b64 %rd2, [grid_const_partial_escape_param_1];
+; PTX-NEXT:    cvta.to.global.u64 %rd3, %rd2;
+; PTX-NEXT:    cvta.param.u64 %rd4, %rd1;
 ; PTX-NEXT:    ld.param.b32 %r1, [grid_const_partial_escape_param_0];
 ; PTX-NEXT:    add.s32 %r2, %r1, %r1;
-; PTX-NEXT:    st.global.b32 [%rd4], %r2;
+; PTX-NEXT:    st.global.b32 [%rd3], %r2;
 ; PTX-NEXT:    { // callseq 2, 0
 ; PTX-NEXT:    .param .b64 param0;
 ; PTX-NEXT:    .param .b32 retval0;
-; PTX-NEXT:    st.param.b64 [param0], %rd5;
+; PTX-NEXT:    st.param.b64 [param0], %rd4;
 ; PTX-NEXT:    prototype_2 : .callprototype (.param .b32 _) _ (.param .b64 _);
-; PTX-NEXT:    mov.b64 %rd1, escape;
-; PTX-NEXT:    call (retval0), %rd1, (param0), prototype_2;
+; PTX-NEXT:    mov.b64 %rd5, escape;
+; PTX-NEXT:    call (retval0), %rd5, (param0), prototype_2;
 ; PTX-NEXT:    } // callseq 2
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel void @grid_const_partial_escape(
@@ -295,21 +295,21 @@ define ptx_kernel i32 @grid_const_partial_escapemem(ptr byval(%struct.s) %input,
 ; PTX-NEXT:    .reg .b64 %rd<6>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0:
-; PTX-NEXT:    mov.b64 %rd2, grid_const_partial_escapemem_param_0;
-; PTX-NEXT:    ld.param.b64 %rd3, [grid_const_partial_escapemem_param_1];
-; PTX-NEXT:    cvta.to.global.u64 %rd4, %rd3;
-; PTX-NEXT:    cvta.param.u64 %rd5, %rd2;
+; PTX-NEXT:    mov.b64 %rd1, grid_const_partial_escapemem_param_0;
+; PTX-NEXT:    ld.param.b64 %rd2, [grid_const_partial_escapemem_param_1];
+; PTX-NEXT:    cvta.to.global.u64 %rd3, %rd2;
+; PTX-NEXT:    cvta.param.u64 %rd4, %rd1;
 ; PTX-NEXT:    ld.param.b32 %r1, [grid_const_partial_escapemem_param_0];
 ; PTX-NEXT:    ld.param.b32 %r2, [grid_const_partial_escapemem_param_0+4];
-; PTX-NEXT:    st.global.b64 [%rd4], %rd5;
+; PTX-NEXT:    st.global.b64 [%rd3], %rd4;
 ; PTX-NEXT:    add.s32 %r3, %r1, %r2;
 ; PTX-NEXT:    { // callseq 3, 0
 ; PTX-NEXT:    .param .b64 param0;
 ; PTX-NEXT:    .param .b32 retval0;
-; PTX-NEXT:    st.param.b64 [param0], %rd5;
+; PTX-NEXT:    st.param.b64 [param0], %rd4;
 ; PTX-NEXT:    prototype_3 : .callprototype (.param .b32 _) _ (.param .b64 _);
-; PTX-NEXT:    mov.b64 %rd1, escape;
-; PTX-NEXT:    call (retval0), %rd1, (param0), prototype_3;
+; PTX-NEXT:    mov.b64 %rd5, escape;
+; PTX-NEXT:    call (retval0), %rd5, (param0), prototype_3;
 ; PTX-NEXT:    } // callseq 3
 ; PTX-NEXT:    st.param.b32 [func_retval0], %r3;
 ; PTX-NEXT:    ret;
diff --git a/llvm/test/CodeGen/NVPTX/param-vectorize-device.ll b/llvm/test/CodeGen/NVPTX/param-vectorize-device.ll
index a592b82614f43..51f6b00601069 100644
--- a/llvm/test/CodeGen/NVPTX/param-vectorize-device.ll
+++ b/llvm/test/CodeGen/NVPTX/param-vectorize-device.ll
@@ -150,8 +150,8 @@ define dso_local void @caller_St4x3(ptr nocapture noundef readonly byval(%struct
   ; CHECK:       )
   ; CHECK:       .param .align 16 .b8 param0[12];
   ; CHECK:       .param .align 16 .b8 retval0[12];
-  ; CHECK:       st.param.v2.b32 [param0], {{{%r[0-9]+}}, {{%r[0-9]+}}};
-  ; CHECK:       st.param.b32    [param0+8], {{%r[0-9]+}};
+  ; CHECK-DAG:   st.param.v2.b32 [param0], {{{%r[0-9]+}}, {{%r[0-9]+}}};
+  ; CHECK-DAG:   st.param.b32    [param0+8], {{%r[0-9]+}};
   ; CHECK:       call.uni (retval0), callee_St4x3, (param0);
   ; CHECK:       ld.param.v2.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0];
   ; CHECK:       ld.param.b32    {{%r[0-9]+}},  [retval0+8];
@@ -240,8 +240,8 @@ define dso_local void @caller_St4x5(ptr nocapture noundef readonly byval(%struct
   ; CHECK:       )
   ; CHECK:       .param .align 16 .b8 param0[20];
   ; CHECK:       .param .align 16 .b8 retval0[20];
-  ; CHECK:       st.param.v4.b32 [param0],  {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
-  ; CHECK:       st.param.b32    [param0+16], {{%r[0-9]+}};
+  ; CHECK-DAG:   st.param.v4.b32 [param0],  {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
+  ; CHECK-DAG:   st.param.b32    [param0+16], {{%r[0-9]+}};
   ; CHECK:       call.uni (retval0), callee_St4x5, (param0);
   ; CHECK:       ld.param.v4.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0];
   ; CHECK:       ld.param.b32    {{%r[0-9]+}},  [retval0+16];
@@ -296,8 +296,8 @@ define dso_local void @caller_St4x6(ptr nocapture noundef readonly byval(%struct
   ; CHECK:       )
   ; CHECK:       .param .align 16 .b8 param0[24];
   ; CHECK:       .param .align 16 .b8 retval0[24];
-  ; CHECK:       st.param.v4.b32 [param0],  {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
-  ; CHECK:       st.param.v2.b32 [param0+16], {{{%r[0-9]+}}, {{%r[0-9]+}}};
+  ; CHECK-DAG:   st.param.v4.b32 [param0],  {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
+  ; CHECK-DAG:   st.param.v2.b32 [param0+16], {{{%r[0-9]+}}, {{%r[0-9]+}}};
   ; CHECK:       call.uni (retval0), callee_St4x6, (param0);
   ; CHECK:       ld.param.v4.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0];
   ; CHECK:       ld.param.v2.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0+16];
@@ -358,9 +358,9 @@ define dso_local void @caller_St4x7(ptr nocapture noundef readonly byval(%struct
   ; CHECK:       )
   ; CHECK:       .param .align 16 .b8 param0[28];
   ; CHECK:       .param .align 16 .b8 retval0[28];
-  ; CHECK:       st.param.v4.b32 [param0],  {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
-  ; CHECK:       st.param.v2.b32 [param0+16], {{{%r[0-9]+}}, {{%r[0-9]+}}};
-  ; CHECK:       st.param.b32    [param0+24], {{%r[0-9]+}};
+  ; CHECK-DAG:   st.param.v4.b32 [param0],  {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
+  ; CHECK-DAG:   st.param.v2.b32 [param0+16], {{{%r[0-9]+}}, {{%r[0-9]+}}};
+  ; CHECK-DAG:   st.param.b32    [param0+24], {{%r[0-9]+}};
   ; CHECK:       call.uni (retval0), callee_St4x7, (param0);
   ; CHECK:       ld.param.v4.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0];
   ; CHECK:       ld.param.v2.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0+16];
@@ -566,8 +566,8 @@ define dso_local void @caller_St8x3(ptr nocapture noundef readonly byval(%struct
   ; CHECK:       )
   ; CHECK:       .param .align 16 .b8 param0[24];
   ; CHECK:       .param .align 16 .b8 retval0[24];
-  ; CHECK:       st.param.v2.b64 [param0],  {{{%rd[0-9]+}}, {{%rd[0-9]+}}};
-  ; CHECK:       st.param.b64    [param0+16], {{%rd[0-9]+}};
+  ; CHECK-DAG:   st.param.v2.b64 [param0],  {{{%rd[0-9]+}}, {{%rd[0-9]+}}};
+  ; CHECK-DAG:   st.param.b64    [param0+16], {{%rd[0-9]+}};
   ; CHECK:       call.uni (retval0), callee_St8x3, (param0);
   ; CHECK:       ld.param.v2.b64 {{{%rd[0-9]+}}, {{%rd[0-9]+}}}, [retval0];
   ; CHECK:       ld.param.b64    {{%rd[0-9]+}}, [retval0+16];



More information about the llvm-commits mailing list