[llvm] [NVPTX] Vectorize loads when lowering byval calls, misc. cleanup (PR #151070)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 30 22:04:34 PDT 2025


https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/151070

>From b28612948c87c35b76521c2e8967375364935e17 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Mon, 28 Jul 2025 22:46:52 +0000
Subject: [PATCH 1/3] pre-commit tests

---
 .../test/CodeGen/NVPTX/byval-arg-vectorize.ll | 40 +++++++++++++++++++
 1 file changed, 40 insertions(+)
 create mode 100644 llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll

diff --git a/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll b/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll
new file mode 100644
index 0000000000000..4756b16751f39
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll
@@ -0,0 +1,40 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mcpu=sm_70 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -mcpu=sm_70 | %ptxas-verify -arch=sm_70 %}
+
+target triple = "nvptx64-nvidia-cuda"
+
+%struct.double2 = type { double, double }
+
+declare %struct.double2 @add(ptr align(16) byval(%struct.double2), ptr align(16) byval(%struct.double2))
+
+define void @call_byval(ptr %out, ptr %in1, ptr %in2) {
+; CHECK-LABEL: call_byval(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b64 %rd<12>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [call_byval_param_0];
+; CHECK-NEXT:    { // callseq 0, 0
+; CHECK-NEXT:    .param .align 16 .b8 param0[16];
+; CHECK-NEXT:    .param .align 16 .b8 param1[16];
+; CHECK-NEXT:    .param .align 8 .b8 retval0[16];
+; CHECK-NEXT:    ld.param.b64 %rd2, [call_byval_param_2];
+; CHECK-NEXT:    ld.b64 %rd3, [%rd2+8];
+; CHECK-NEXT:    ld.b64 %rd4, [%rd2];
+; CHECK-NEXT:    st.param.v2.b64 [param1], {%rd4, %rd3};
+; CHECK-NEXT:    ld.param.b64 %rd5, [call_byval_param_1];
+; CHECK-NEXT:    ld.b64 %rd6, [%rd5+8];
+; CHECK-NEXT:    ld.b64 %rd7, [%rd5];
+; CHECK-NEXT:    st.param.v2.b64 [param0], {%rd7, %rd6};
+; CHECK-NEXT:    call.uni (retval0), add, (param0, param1);
+; CHECK-NEXT:    ld.param.b64 %rd8, [retval0+8];
+; CHECK-NEXT:    ld.param.b64 %rd9, [retval0];
+; CHECK-NEXT:    } // callseq 0
+; CHECK-NEXT:    st.b64 [%rd1+8], %rd8;
+; CHECK-NEXT:    st.b64 [%rd1], %rd9;
+; CHECK-NEXT:    ret;
+  %call = call %struct.double2 @add(ptr align(16) byval(%struct.double2) %in1, ptr align(16) byval(%struct.double2) %in2)
+  store %struct.double2 %call, ptr %out, align 16
+  ret void
+}

>From cdfc0e34cc1c45ef2341bd85077cd1ca9bd28303 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Tue, 29 Jul 2025 02:02:37 +0000
Subject: [PATCH 2/3] [NVPTX] Vectorize loads when lowering byval calls,
 misc. cleanup

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   | 376 +++++++++---------
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td       |   6 -
 .../test/CodeGen/NVPTX/byval-arg-vectorize.ll |  10 +-
 .../CodeGen/NVPTX/convert-call-to-indirect.ll |  24 +-
 .../CodeGen/NVPTX/lower-args-gridconstant.ll  |  66 +--
 .../CodeGen/NVPTX/param-vectorize-device.ll   |  22 +-
 6 files changed, 245 insertions(+), 259 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index f79b8629f01e2..3e48e61f8c6c8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -382,6 +382,51 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
   }
 }
 
+static EVT getVectorizedVT(EVT VT, unsigned N, LLVMContext &C) {
+  if (N == 1)
+    return VT;
+
+  const unsigned PackingAmt = VT.isVector() ? VT.getVectorNumElements() : 1;
+  return EVT::getVectorVT(C, VT.getScalarType(), N * PackingAmt);
+}
+
+static SDValue getExtractVectorizedValue(SDValue V, unsigned I, EVT VT,
+                                         const SDLoc &dl, SelectionDAG &DAG) {
+  if (V.getValueType() == VT) {
+    assert(I == 0 && "Index must be 0 for scalar value");
+    return V;
+  }
+
+  if (!VT.isVector())
+    return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, V,
+                       DAG.getVectorIdxConstant(I, dl));
+
+  return DAG.getNode(
+      ISD::EXTRACT_SUBVECTOR, dl, VT, V,
+      DAG.getVectorIdxConstant(I * VT.getVectorNumElements(), dl));
+}
+
+template <typename T>
+static inline SDValue getBuildVectorizedValue(T GetElement, unsigned N,
+                                              const SDLoc &dl,
+                                              SelectionDAG &DAG) {
+  if (N == 1)
+    return GetElement(0);
+
+  SmallVector<SDValue, 8> Values;
+  for (const unsigned I : llvm::seq(N)) {
+    SDValue Val = GetElement(I);
+    if (Val.getValueType().isVector())
+      DAG.ExtractVectorElements(Val, Values);
+    else
+      Values.push_back(Val);
+  }
+
+  EVT VT = EVT::getVectorVT(*DAG.getContext(), Values[0].getValueType(),
+                            Values.size());
+  return DAG.getBuildVector(VT, dl, Values);
+}
+
 /// PromoteScalarIntegerPTX
 /// Used to make sure the arguments/returns are suitable for passing
 /// and promote them to a larger size if they're not.
@@ -420,9 +465,10 @@ static EVT promoteScalarIntegerPTX(const EVT VT) {
 // parameter starting at index Idx using a single vectorized op of
 // size AccessSize. If so, it returns the number of param pieces
 // covered by the vector op. Otherwise, it returns 1.
-static unsigned CanMergeParamLoadStoresStartingAt(
+template <typename T>
+static unsigned canMergeParamLoadStoresStartingAt(
     unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs,
-    const SmallVectorImpl<uint64_t> &Offsets, Align ParamAlignment) {
+    const SmallVectorImpl<T> &Offsets, Align ParamAlignment) {
 
   // Can't vectorize if param alignment is not sufficient.
   if (ParamAlignment < AccessSize)
@@ -472,10 +518,11 @@ static unsigned CanMergeParamLoadStoresStartingAt(
 // of the same size as ValueVTs indicating how each piece should be
 // loaded/stored (i.e. as a scalar, or as part of a vector
 // load/store).
+template <typename T>
 static SmallVector<unsigned, 16>
 VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
-                     const SmallVectorImpl<uint64_t> &Offsets,
-                     Align ParamAlignment, bool IsVAArg = false) {
+                     const SmallVectorImpl<T> &Offsets, Align ParamAlignment,
+                     bool IsVAArg = false) {
   // Set vector size to match ValueVTs and mark all elements as
   // scalars by default.
 
@@ -486,7 +533,7 @@ VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
 
   const auto GetNumElts = [&](unsigned I) -> unsigned {
     for (const unsigned AccessSize : {16, 8, 4, 2}) {
-      const unsigned NumElts = CanMergeParamLoadStoresStartingAt(
+      const unsigned NumElts = canMergeParamLoadStoresStartingAt(
           I, AccessSize, ValueVTs, Offsets, ParamAlignment);
       assert((NumElts == 1 || NumElts == 2 || NumElts == 4) &&
              "Unexpected vectorization size");
@@ -1384,6 +1431,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   Type *RetTy = CLI.RetTy;
   const CallBase *CB = CLI.CB;
   const DataLayout &DL = DAG.getDataLayout();
+  LLVMContext &Ctx = *DAG.getContext();
 
   const auto GetI32 = [&](const unsigned I) {
     return DAG.getConstant(I, dl, MVT::i32);
@@ -1476,15 +1524,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     const SDValue ParamSymbol =
         getCallParamSymbol(DAG, IsVAArg ? FirstVAArg : ArgI, MVT::i32);
 
-    SmallVector<EVT, 16> VTs;
-    SmallVector<uint64_t, 16> Offsets;
-
     assert((!IsByVal || Arg.IndirectType) &&
            "byval arg must have indirect type");
     Type *ETy = (IsByVal ? Arg.IndirectType : Arg.Ty);
-    ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, IsByVal ? 0 : VAOffset);
-    assert(VTs.size() == Offsets.size() && "Size mismatch");
-    assert((IsByVal || VTs.size() == ArgOuts.size()) && "Size mismatch");
 
     const Align ArgAlign = [&]() {
       if (IsByVal) {
@@ -1492,17 +1534,14 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
         // so we don't need to worry whether it's naturally aligned or not.
         // See TargetLowering::LowerCallTo().
         const Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
-        const Align ByValAlign = getFunctionByValParamAlign(
-            CB->getCalledFunction(), ETy, InitialAlign, DL);
-        if (IsVAArg)
-          VAOffset = alignTo(VAOffset, ByValAlign);
-        return ByValAlign;
+        return getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
+                                          InitialAlign, DL);
       }
       return getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL);
     }();
 
-    const unsigned TypeSize = DL.getTypeAllocSize(ETy);
-    assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) &&
+    const unsigned TySize = DL.getTypeAllocSize(ETy);
+    assert((!IsByVal || TySize == ArgOuts[0].Flags.getByValSize()) &&
            "type size mismatch");
 
     const SDValue ArgDeclare = [&]() {
@@ -1510,105 +1549,119 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
         return VADeclareParam;
 
       if (IsByVal || shouldPassAsArray(Arg.Ty))
-        return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TypeSize);
+        return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TySize);
 
       assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
       assert((ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint()) &&
              "Only int and float types are supported as non-array arguments");
 
-      return MakeDeclareScalarParam(ParamSymbol, TypeSize);
+      return MakeDeclareScalarParam(ParamSymbol, TySize);
     }();
 
-    // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
-    // than 32-bits are sign extended or zero extended, depending on
-    // whether they are signed or unsigned types. This case applies
-    // only to scalar parameters and not to aggregate values.
-    const bool ExtendIntegerParam =
-        Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
+    if (IsByVal) {
+      assert(ArgOutVals.size() == 1 && "We must pass only one value as byval");
+      SDValue SrcPtr = ArgOutVals[0];
+      const auto PointerInfo = refinePtrAS(SrcPtr, DAG, DL, *this);
+      const Align BaseSrcAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
 
-    const auto GetStoredValue = [&](const unsigned I, EVT EltVT,
-                                    const MaybeAlign PartAlign) {
-      if (IsByVal) {
-        SDValue Ptr = ArgOutVals[0];
-        auto MPI = refinePtrAS(Ptr, DAG, DL, *this);
-        SDValue SrcAddr =
-            DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(Offsets[I]));
-
-        return DAG.getLoad(EltVT, dl, CallChain, SrcAddr, MPI, PartAlign);
+      if (IsVAArg)
+        VAOffset = alignTo(VAOffset, ArgAlign);
+
+      SmallVector<EVT, 4> ValueVTs, MemVTs;
+      SmallVector<TypeSize, 4> Offsets;
+      ComputeValueVTs(*this, DL, ETy, ValueVTs, &MemVTs, &Offsets);
+
+      unsigned J = 0;
+      const auto VI = VectorizePTXValueVTs(MemVTs, Offsets, ArgAlign, IsVAArg);
+      for (const unsigned NumElts : VI) {
+        EVT LoadVT = getVectorizedVT(MemVTs[J], NumElts, Ctx);
+        Align SrcAlign = commonAlignment(BaseSrcAlign, Offsets[J]);
+        SDValue SrcAddr = DAG.getObjectPtrOffset(dl, SrcPtr, Offsets[J]);
+        SDValue SrcLoad =
+            DAG.getLoad(LoadVT, dl, CallChain, SrcAddr, PointerInfo, SrcAlign);
+
+        TypeSize ParamOffset = Offsets[J].getWithIncrement(VAOffset);
+        Align ParamAlign = commonAlignment(ArgAlign, ParamOffset);
+        SDValue ParamAddr =
+            DAG.getObjectPtrOffset(dl, ParamSymbol, ParamOffset);
+        SDValue StoreParam =
+            DAG.getStore(ArgDeclare, dl, SrcLoad, ParamAddr,
+                         MachinePointerInfo(ADDRESS_SPACE_PARAM), ParamAlign);
+        CallPrereqs.push_back(StoreParam);
+
+        J += NumElts;
       }
-      SDValue StVal = ArgOutVals[I];
-      assert(promoteScalarIntegerPTX(StVal.getValueType()) ==
-                 StVal.getValueType() &&
-             "OutVal type should always be legal");
-
-      const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
-      const EVT StoreVT =
-          ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
-
-      return correctParamType(StVal, StoreVT, ArgOuts[I].Flags, DAG, dl);
-    };
-
-    const auto VectorInfo =
-        VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
-
-    unsigned J = 0;
-    for (const unsigned NumElts : VectorInfo) {
-      const int CurOffset = Offsets[J];
-      const EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
-
-      if (IsVAArg && !IsByVal)
-        // Align each part of the variadic argument to their type.
-        VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
-
-      assert((IsVAArg || VAOffset == 0) &&
-             "VAOffset must be 0 for non-VA args");
+      if (IsVAArg)
+        VAOffset += TySize;
+    } else {
+      SmallVector<EVT, 16> VTs;
+      SmallVector<uint64_t, 16> Offsets;
+      ComputePTXValueVTs(*this, DL, Arg.Ty, VTs, &Offsets, VAOffset);
+      assert(VTs.size() == Offsets.size() && "Size mismatch");
+      assert(VTs.size() == ArgOuts.size() && "Size mismatch");
+
+      // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
+      // than 32-bits are sign extended or zero extended, depending on
+      // whether they are signed or unsigned types. This case applies
+      // only to scalar parameters and not to aggregate values.
+      const bool ExtendIntegerParam =
+          Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
+
+      const auto GetStoredValue = [&](const unsigned I) {
+        SDValue StVal = ArgOutVals[I];
+        assert(promoteScalarIntegerPTX(StVal.getValueType()) ==
+                   StVal.getValueType() &&
+               "OutVal type should always be legal");
+
+        const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
+        const EVT StoreVT =
+            ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
+
+        return correctParamType(StVal, StoreVT, ArgOuts[I].Flags, DAG, dl);
+      };
+
+      unsigned J = 0;
+      const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
+      for (const unsigned NumElts : VI) {
+        const EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
+
+        unsigned Offset;
+        if (IsVAArg) {
+          // TODO: We may need to support vector types that can be passed
+          // as scalars in variadic arguments.
+          assert(NumElts == 1 &&
+                 "Vectorization should be disabled for vaargs.");
+
+          // Align each part of the variadic argument to their type.
+          VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
+          Offset = VAOffset;
+
+          const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
+          VAOffset += DL.getTypeAllocSize(TheStoreType.getTypeForEVT(Ctx));
+        } else {
+          assert(VAOffset == 0 && "VAOffset must be 0 for non-VA args");
+          Offset = Offsets[J];
+        }
 
-      const unsigned Offset =
-          (VAOffset + ((IsVAArg && !IsByVal) ? 0 : CurOffset));
-      SDValue Ptr =
-          DAG.getObjectPtrOffset(dl, ParamSymbol, TypeSize::getFixed(Offset));
+        SDValue Ptr =
+            DAG.getObjectPtrOffset(dl, ParamSymbol, TypeSize::getFixed(Offset));
 
-      const MaybeAlign CurrentAlign = ExtendIntegerParam
-                                          ? MaybeAlign(std::nullopt)
-                                          : commonAlignment(ArgAlign, Offset);
+        const MaybeAlign CurrentAlign = ExtendIntegerParam
+                                            ? MaybeAlign(std::nullopt)
+                                            : commonAlignment(ArgAlign, Offset);
 
-      SDValue Val;
-      if (NumElts == 1) {
-        Val = GetStoredValue(J, EltVT, CurrentAlign);
-      } else {
-        SmallVector<SDValue, 8> StoreVals;
-        for (const unsigned K : llvm::seq(NumElts)) {
-          SDValue ValJ = GetStoredValue(J + K, EltVT, CurrentAlign);
-          if (ValJ.getValueType().isVector())
-            DAG.ExtractVectorElements(ValJ, StoreVals);
-          else
-            StoreVals.push_back(ValJ);
-        }
+        SDValue Val = getBuildVectorizedValue(
+            [&](unsigned K) { return GetStoredValue(J + K); }, NumElts, dl,
+            DAG);
 
-        EVT VT = EVT::getVectorVT(
-            *DAG.getContext(), StoreVals[0].getValueType(), StoreVals.size());
-        Val = DAG.getBuildVector(VT, dl, StoreVals);
-      }
+        SDValue StoreParam =
+            DAG.getStore(ArgDeclare, dl, Val, Ptr,
+                         MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
+        CallPrereqs.push_back(StoreParam);
 
-      SDValue StoreParam =
-          DAG.getStore(ArgDeclare, dl, Val, Ptr,
-                       MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
-      CallPrereqs.push_back(StoreParam);
-
-      // TODO: We may need to support vector types that can be passed
-      // as scalars in variadic arguments.
-      if (IsVAArg && !IsByVal) {
-        assert(NumElts == 1 &&
-               "Vectorization is expected to be disabled for variadics.");
-        const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
-        VAOffset +=
-            DL.getTypeAllocSize(TheStoreType.getTypeForEVT(*DAG.getContext()));
+        J += NumElts;
       }
-
-      J += NumElts;
     }
-    if (IsVAArg && IsByVal)
-      VAOffset += TypeSize;
   }
 
   // Handle Result
@@ -1676,17 +1729,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     CallPrereqs.push_back(PrototypeDeclare);
   }
 
-  if (ConvertToIndirectCall) {
-    // Copy the function ptr to a ptx register and use the register to call the
-    // function.
-    const MVT DestVT = Callee.getValueType().getSimpleVT();
-    MachineRegisterInfo &MRI = DAG.getMachineFunction().getRegInfo();
-    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-    Register DestReg = MRI.createVirtualRegister(TLI.getRegClassFor(DestVT));
-    auto RegCopy = DAG.getCopyToReg(DAG.getEntryNode(), dl, DestReg, Callee);
-    Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT);
-  }
-
   const unsigned Proto = IsIndirectCall ? UniqueCallSite : 0;
   const unsigned NumArgs =
       std::min<unsigned>(CLI.NumFixedArgs + 1, Args.size());
@@ -1703,10 +1745,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   if (!Ins.empty()) {
     SmallVector<EVT, 16> VTs;
     SmallVector<uint64_t, 16> Offsets;
-    ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0);
+    ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
     assert(VTs.size() == Ins.size() && "Bad value decomposition");
 
     const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
+    const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
 
     // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
     // 32-bits are sign extended or zero extended, depending on whether
@@ -1714,9 +1757,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     const bool ExtendIntegerRetVal =
         RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
 
-    const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
     unsigned I = 0;
-    for (const unsigned NumElts : VectorInfo) {
+    const auto VI = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
+    for (const unsigned NumElts : VI) {
       const MaybeAlign CurrentAlign =
           ExtendIntegerRetVal ? MaybeAlign(std::nullopt)
                               : commonAlignment(RetAlign, Offsets[I]);
@@ -1724,16 +1767,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
       const EVT LoadVT =
           ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
-
-      const unsigned PackingAmt =
-          LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1;
-
-      const EVT VecVT = NumElts == 1 ? LoadVT
-                                     : EVT::getVectorVT(*DAG.getContext(),
-                                                        LoadVT.getScalarType(),
-                                                        NumElts * PackingAmt);
-
-      const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
+      const EVT VecVT = getVectorizedVT(LoadVT, NumElts, Ctx);
       SDValue Ptr =
           DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
 
@@ -1742,17 +1776,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                       MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
 
       LoadChains.push_back(R.getValue(1));
-
-      if (NumElts == 1)
-        ProxyRegOps.push_back(R);
-      else
-        for (const unsigned J : llvm::seq(NumElts)) {
-          SDValue Elt = DAG.getNode(
-              LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
-                                : ISD::EXTRACT_VECTOR_ELT,
-              dl, LoadVT, R, DAG.getVectorIdxConstant(J * PackingAmt, dl));
-          ProxyRegOps.push_back(Elt);
-        }
+      for (const unsigned J : llvm::seq(NumElts))
+        ProxyRegOps.push_back(getExtractVectorizedValue(R, J, LoadVT, dl, DAG));
       I += NumElts;
     }
   }
@@ -3227,11 +3252,10 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
     SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
     const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
     SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
-  MachineFunction &MF = DAG.getMachineFunction();
   const DataLayout &DL = DAG.getDataLayout();
   auto PtrVT = getPointerTy(DAG.getDataLayout());
 
-  const Function *F = &MF.getFunction();
+  const Function &F = DAG.getMachineFunction().getFunction();
 
   SDValue Root = DAG.getRoot();
   SmallVector<SDValue, 16> OutChains;
@@ -3247,7 +3271,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
   // See similar issue in LowerCall.
 
   auto AllIns = ArrayRef(Ins);
-  for (const auto &Arg : F->args()) {
+  for (const auto &Arg : F.args()) {
     const auto ArgIns = AllIns.take_while(
         [&](auto I) { return I.OrigArgIndex == Arg.getArgNo(); });
     AllIns = AllIns.drop_front(ArgIns.size());
@@ -3287,7 +3311,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
       assert(ByvalIn.VT == PtrVT && "ByVal argument must be a pointer");
 
       SDValue P;
-      if (isKernelFunction(*F)) {
+      if (isKernelFunction(F)) {
         P = ArgSymbol;
         P.getNode()->setIROrder(Arg.getArgNo() + 1);
       } else {
@@ -3305,43 +3329,27 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
       assert(VTs.size() == Offsets.size() && "Size mismatch");
 
       const Align ArgAlign = getFunctionArgumentAlignment(
-          F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
+          &F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
 
-      const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
       unsigned I = 0;
-      for (const unsigned NumElts : VectorInfo) {
+      const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
+      for (const unsigned NumElts : VI) {
         // i1 is loaded/stored as i8
         const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
-        // If the element is a packed type (ex. v2f16, v4i8, etc) holding
-        // multiple elements.
-        const unsigned PackingAmt =
-            LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1;
-
-        const EVT VecVT =
-            NumElts == 1
-                ? LoadVT
-                : EVT::getVectorVT(F->getContext(), LoadVT.getScalarType(),
-                                   NumElts * PackingAmt);
+        const EVT VecVT = getVectorizedVT(LoadVT, NumElts, *DAG.getContext());
 
         SDValue VecAddr = DAG.getObjectPtrOffset(
             dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
 
-        const MaybeAlign PartAlign = commonAlignment(ArgAlign, Offsets[I]);
+        const Align PartAlign = commonAlignment(ArgAlign, Offsets[I]);
         SDValue P =
             DAG.getLoad(VecVT, dl, Root, VecAddr,
                         MachinePointerInfo(ADDRESS_SPACE_PARAM), PartAlign,
                         MachineMemOperand::MODereferenceable |
                             MachineMemOperand::MOInvariant);
-        if (P.getNode())
-          P.getNode()->setIROrder(Arg.getArgNo() + 1);
+        P.getNode()->setIROrder(Arg.getArgNo() + 1);
         for (const unsigned J : llvm::seq(NumElts)) {
-          SDValue Elt =
-              NumElts == 1
-                  ? P
-                  : DAG.getNode(LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
-                                                  : ISD::EXTRACT_VECTOR_ELT,
-                                dl, LoadVT, P,
-                                DAG.getVectorIdxConstant(J * PackingAmt, dl));
+          SDValue Elt = getExtractVectorizedValue(P, J, LoadVT, dl, DAG);
 
           Elt = correctParamType(Elt, ArgIns[I + J].VT, ArgIns[I + J].Flags,
                                  DAG, dl);
@@ -3364,9 +3372,8 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
                                  const SmallVectorImpl<ISD::OutputArg> &Outs,
                                  const SmallVectorImpl<SDValue> &OutVals,
                                  const SDLoc &dl, SelectionDAG &DAG) const {
-  const MachineFunction &MF = DAG.getMachineFunction();
-  const Function &F = MF.getFunction();
-  Type *RetTy = MF.getFunction().getReturnType();
+  const Function &F = DAG.getMachineFunction().getFunction();
+  Type *RetTy = F.getReturnType();
 
   if (RetTy->isVoidTy()) {
     assert(OutVals.empty() && Outs.empty() && "Return value expected for void");
@@ -3374,10 +3381,9 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
   }
 
   const DataLayout &DL = DAG.getDataLayout();
-  SmallVector<EVT, 16> VTs;
-  SmallVector<uint64_t, 16> Offsets;
-  ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
-  assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
+
+  const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
+  const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL);
 
   // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
   // 32-bits are sign extended or zero extended, depending on whether
@@ -3385,6 +3391,11 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
   const bool ExtendIntegerRetVal =
       RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
 
+  SmallVector<EVT, 16> VTs;
+  SmallVector<uint64_t, 16> Offsets;
+  ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
+  assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
+
   const auto GetRetVal = [&](unsigned I) -> SDValue {
     SDValue RetVal = OutVals[I];
     assert(promoteScalarIntegerPTX(RetVal.getValueType()) ==
@@ -3397,33 +3408,16 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
     return correctParamType(RetVal, StoreVT, Outs[I].Flags, DAG, dl);
   };
 
-  const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL);
-  const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
   unsigned I = 0;
-  for (const unsigned NumElts : VectorInfo) {
+  const auto VI = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
+  for (const unsigned NumElts : VI) {
     const MaybeAlign CurrentAlign = ExtendIntegerRetVal
                                         ? MaybeAlign(std::nullopt)
                                         : commonAlignment(RetAlign, Offsets[I]);
 
-    SDValue Val;
-    if (NumElts == 1) {
-      Val = GetRetVal(I);
-    } else {
-      SmallVector<SDValue, 4> StoreVals;
-      for (const unsigned J : llvm::seq(NumElts)) {
-        SDValue ValJ = GetRetVal(I + J);
-        if (ValJ.getValueType().isVector())
-          DAG.ExtractVectorElements(ValJ, StoreVals);
-        else
-          StoreVals.push_back(ValJ);
-      }
-
-      EVT VT = EVT::getVectorVT(F.getContext(), StoreVals[0].getValueType(),
-                                StoreVals.size());
-      Val = DAG.getBuildVector(VT, dl, StoreVals);
-    }
+    SDValue Val = getBuildVectorizedValue(
+        [&](unsigned K) { return GetRetVal(I + K); }, NumElts, dl, DAG);
 
-    const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
     SDValue Ptr =
         DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 86d6f7c3fc3a3..45d3a6044826c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1802,8 +1802,6 @@ foreach is_convergent = [0, 1] in {
   }
 
   defvar call_inst = !cast<NVPTXInst>("CALL" # convergent_suffix);
-  def : Pat<(call is_convergent, 1, imm:$rets, imm:$params, globaladdr:$addr, imm:$proto),
-            (call_inst (to_tglobaladdr $addr), imm:$rets, imm:$params, imm:$proto)>;
   def : Pat<(call is_convergent, 1, imm:$rets, imm:$params, i32:$addr, imm:$proto),
             (call_inst $addr, imm:$rets, imm:$params, imm:$proto)>;
   def : Pat<(call is_convergent, 1, imm:$rets, imm:$params, i64:$addr, imm:$proto),
@@ -1812,10 +1810,6 @@ foreach is_convergent = [0, 1] in {
   defvar call_uni_inst = !cast<NVPTXInst>("CALL_UNI" # convergent_suffix);
   def : Pat<(call is_convergent, 0, imm:$rets, imm:$params, globaladdr:$addr, 0),
             (call_uni_inst (to_tglobaladdr $addr), imm:$rets, imm:$params)>;
-  def : Pat<(call is_convergent, 0, imm:$rets, imm:$params, i32:$addr, 0),
-            (call_uni_inst $addr, imm:$rets, imm:$params)>;
-  def : Pat<(call is_convergent, 0, imm:$rets, imm:$params, i64:$addr, 0),
-            (call_uni_inst $addr, imm:$rets, imm:$params)>;
 }
 
 def DECLARE_PARAM_array :
diff --git a/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll b/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll
index 4756b16751f39..9988d5b122cc1 100644
--- a/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll
+++ b/llvm/test/CodeGen/NVPTX/byval-arg-vectorize.ll
@@ -20,13 +20,11 @@ define void @call_byval(ptr %out, ptr %in1, ptr %in2) {
 ; CHECK-NEXT:    .param .align 16 .b8 param1[16];
 ; CHECK-NEXT:    .param .align 8 .b8 retval0[16];
 ; CHECK-NEXT:    ld.param.b64 %rd2, [call_byval_param_2];
-; CHECK-NEXT:    ld.b64 %rd3, [%rd2+8];
-; CHECK-NEXT:    ld.b64 %rd4, [%rd2];
-; CHECK-NEXT:    st.param.v2.b64 [param1], {%rd4, %rd3};
+; CHECK-NEXT:    ld.v2.b64 {%rd3, %rd4}, [%rd2];
+; CHECK-NEXT:    st.param.v2.b64 [param1], {%rd3, %rd4};
 ; CHECK-NEXT:    ld.param.b64 %rd5, [call_byval_param_1];
-; CHECK-NEXT:    ld.b64 %rd6, [%rd5+8];
-; CHECK-NEXT:    ld.b64 %rd7, [%rd5];
-; CHECK-NEXT:    st.param.v2.b64 [param0], {%rd7, %rd6};
+; CHECK-NEXT:    ld.v2.b64 {%rd6, %rd7}, [%rd5];
+; CHECK-NEXT:    st.param.v2.b64 [param0], {%rd6, %rd7};
 ; CHECK-NEXT:    call.uni (retval0), add, (param0, param1);
 ; CHECK-NEXT:    ld.param.b64 %rd8, [retval0+8];
 ; CHECK-NEXT:    ld.param.b64 %rd9, [retval0];
diff --git a/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll b/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll
index 48209a8c88682..dd3e4ecddcd2e 100644
--- a/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll
+++ b/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll
@@ -12,14 +12,14 @@ define %struct.64 @test_return_type_mismatch(ptr %p) {
 ; CHECK-NEXT:    .reg .b64 %rd<40>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.b64 %rd2, [test_return_type_mismatch_param_0];
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_return_type_mismatch_param_0];
 ; CHECK-NEXT:    { // callseq 0, 0
 ; CHECK-NEXT:    .param .b64 param0;
 ; CHECK-NEXT:    .param .align 1 .b8 retval0[8];
-; CHECK-NEXT:    st.param.b64 [param0], %rd2;
+; CHECK-NEXT:    st.param.b64 [param0], %rd1;
 ; CHECK-NEXT:    prototype_0 : .callprototype (.param .align 1 .b8 _[8]) _ (.param .b64 _);
-; CHECK-NEXT:    mov.b64 %rd1, callee;
-; CHECK-NEXT:    call (retval0), %rd1, (param0), prototype_0;
+; CHECK-NEXT:    mov.b64 %rd2, callee;
+; CHECK-NEXT:    call (retval0), %rd2, (param0), prototype_0;
 ; CHECK-NEXT:    ld.param.b8 %rd3, [retval0+7];
 ; CHECK-NEXT:    ld.param.b8 %rd4, [retval0+6];
 ; CHECK-NEXT:    ld.param.b8 %rd5, [retval0+5];
@@ -90,16 +90,16 @@ define i64 @test_param_count_mismatch(ptr %p) {
 ; CHECK-NEXT:    .reg .b64 %rd<5>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.b64 %rd2, [test_param_count_mismatch_param_0];
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_param_count_mismatch_param_0];
 ; CHECK-NEXT:    { // callseq 2, 0
 ; CHECK-NEXT:    .param .b64 param0;
 ; CHECK-NEXT:    .param .b64 param1;
 ; CHECK-NEXT:    .param .b64 retval0;
-; CHECK-NEXT:    st.param.b64 [param0], %rd2;
+; CHECK-NEXT:    st.param.b64 [param0], %rd1;
 ; CHECK-NEXT:    prototype_2 : .callprototype (.param .b64 _) _ (.param .b64 _, .param .b64 _);
 ; CHECK-NEXT:    st.param.b64 [param1], 7;
-; CHECK-NEXT:    mov.b64 %rd1, callee;
-; CHECK-NEXT:    call (retval0), %rd1, (param0, param1), prototype_2;
+; CHECK-NEXT:    mov.b64 %rd2, callee;
+; CHECK-NEXT:    call (retval0), %rd2, (param0, param1), prototype_2;
 ; CHECK-NEXT:    ld.param.b64 %rd3, [retval0];
 ; CHECK-NEXT:    } // callseq 2
 ; CHECK-NEXT:    st.param.b64 [func_retval0], %rd3;
@@ -114,14 +114,14 @@ define %struct.64 @test_return_type_mismatch_variadic(ptr %p) {
 ; CHECK-NEXT:    .reg .b64 %rd<40>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.b64 %rd2, [test_return_type_mismatch_variadic_param_0];
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_return_type_mismatch_variadic_param_0];
 ; CHECK-NEXT:    { // callseq 3, 0
 ; CHECK-NEXT:    .param .b64 param0;
 ; CHECK-NEXT:    .param .align 1 .b8 retval0[8];
-; CHECK-NEXT:    st.param.b64 [param0], %rd2;
+; CHECK-NEXT:    st.param.b64 [param0], %rd1;
 ; CHECK-NEXT:    prototype_3 : .callprototype (.param .align 1 .b8 _[8]) _ (.param .b64 _);
-; CHECK-NEXT:    mov.b64 %rd1, callee_variadic;
-; CHECK-NEXT:    call (retval0), %rd1, (param0), prototype_3;
+; CHECK-NEXT:    mov.b64 %rd2, callee_variadic;
+; CHECK-NEXT:    call (retval0), %rd2, (param0), prototype_3;
 ; CHECK-NEXT:    ld.param.b8 %rd3, [retval0+7];
 ; CHECK-NEXT:    ld.param.b8 %rd4, [retval0+6];
 ; CHECK-NEXT:    ld.param.b8 %rd5, [retval0+5];
diff --git a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
index 38185c7bf30de..045704bdcd3fc 100644
--- a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
@@ -124,15 +124,15 @@ define ptx_kernel void @grid_const_escape(ptr byval(%struct.s) align 4 %input) {
 ; PTX-NEXT:    .reg .b64 %rd<4>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0:
-; PTX-NEXT:    mov.b64 %rd2, grid_const_escape_param_0;
-; PTX-NEXT:    cvta.param.u64 %rd3, %rd2;
+; PTX-NEXT:    mov.b64 %rd1, grid_const_escape_param_0;
+; PTX-NEXT:    cvta.param.u64 %rd2, %rd1;
 ; PTX-NEXT:    { // callseq 0, 0
 ; PTX-NEXT:    .param .b64 param0;
 ; PTX-NEXT:    .param .b32 retval0;
-; PTX-NEXT:    st.param.b64 [param0], %rd3;
+; PTX-NEXT:    st.param.b64 [param0], %rd2;
 ; PTX-NEXT:    prototype_0 : .callprototype (.param .b32 _) _ (.param .b64 _);
-; PTX-NEXT:    mov.b64 %rd1, escape;
-; PTX-NEXT:    call (retval0), %rd1, (param0), prototype_0;
+; PTX-NEXT:    mov.b64 %rd3, escape;
+; PTX-NEXT:    call (retval0), %rd3, (param0), prototype_0;
 ; PTX-NEXT:    } // callseq 0
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel void @grid_const_escape(
@@ -157,25 +157,25 @@ define ptx_kernel void @multiple_grid_const_escape(ptr byval(%struct.s) align 4
 ; PTX-NEXT:  // %bb.0:
 ; PTX-NEXT:    mov.b64 %SPL, __local_depot4;
 ; PTX-NEXT:    cvta.local.u64 %SP, %SPL;
-; PTX-NEXT:    mov.b64 %rd2, multiple_grid_const_escape_param_0;
+; PTX-NEXT:    mov.b64 %rd1, multiple_grid_const_escape_param_0;
 ; PTX-NEXT:    ld.param.b32 %r1, [multiple_grid_const_escape_param_1];
-; PTX-NEXT:    mov.b64 %rd3, multiple_grid_const_escape_param_2;
-; PTX-NEXT:    cvta.param.u64 %rd4, %rd3;
-; PTX-NEXT:    cvta.param.u64 %rd5, %rd2;
-; PTX-NEXT:    add.u64 %rd6, %SP, 0;
-; PTX-NEXT:    add.u64 %rd7, %SPL, 0;
-; PTX-NEXT:    st.local.b32 [%rd7], %r1;
+; PTX-NEXT:    mov.b64 %rd2, multiple_grid_const_escape_param_2;
+; PTX-NEXT:    cvta.param.u64 %rd3, %rd2;
+; PTX-NEXT:    cvta.param.u64 %rd4, %rd1;
+; PTX-NEXT:    add.u64 %rd5, %SP, 0;
+; PTX-NEXT:    add.u64 %rd6, %SPL, 0;
+; PTX-NEXT:    st.local.b32 [%rd6], %r1;
 ; PTX-NEXT:    { // callseq 1, 0
 ; PTX-NEXT:    .param .b64 param0;
 ; PTX-NEXT:    .param .b64 param1;
 ; PTX-NEXT:    .param .b64 param2;
 ; PTX-NEXT:    .param .b32 retval0;
-; PTX-NEXT:    st.param.b64 [param2], %rd4;
-; PTX-NEXT:    st.param.b64 [param1], %rd6;
-; PTX-NEXT:    st.param.b64 [param0], %rd5;
+; PTX-NEXT:    st.param.b64 [param2], %rd3;
+; PTX-NEXT:    st.param.b64 [param1], %rd5;
+; PTX-NEXT:    st.param.b64 [param0], %rd4;
 ; PTX-NEXT:    prototype_1 : .callprototype (.param .b32 _) _ (.param .b64 _, .param .b64 _, .param .b64 _);
-; PTX-NEXT:    mov.b64 %rd1, escape3;
-; PTX-NEXT:    call (retval0), %rd1, (param0, param1, param2), prototype_1;
+; PTX-NEXT:    mov.b64 %rd7, escape3;
+; PTX-NEXT:    call (retval0), %rd7, (param0, param1, param2), prototype_1;
 ; PTX-NEXT:    } // callseq 1
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel void @multiple_grid_const_escape(
@@ -256,20 +256,20 @@ define ptx_kernel void @grid_const_partial_escape(ptr byval(i32) %input, ptr %ou
 ; PTX-NEXT:    .reg .b64 %rd<6>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0:
-; PTX-NEXT:    mov.b64 %rd2, grid_const_partial_escape_param_0;
-; PTX-NEXT:    ld.param.b64 %rd3, [grid_const_partial_escape_param_1];
-; PTX-NEXT:    cvta.to.global.u64 %rd4, %rd3;
-; PTX-NEXT:    cvta.param.u64 %rd5, %rd2;
+; PTX-NEXT:    mov.b64 %rd1, grid_const_partial_escape_param_0;
+; PTX-NEXT:    ld.param.b64 %rd2, [grid_const_partial_escape_param_1];
+; PTX-NEXT:    cvta.to.global.u64 %rd3, %rd2;
+; PTX-NEXT:    cvta.param.u64 %rd4, %rd1;
 ; PTX-NEXT:    ld.param.b32 %r1, [grid_const_partial_escape_param_0];
 ; PTX-NEXT:    add.s32 %r2, %r1, %r1;
-; PTX-NEXT:    st.global.b32 [%rd4], %r2;
+; PTX-NEXT:    st.global.b32 [%rd3], %r2;
 ; PTX-NEXT:    { // callseq 2, 0
 ; PTX-NEXT:    .param .b64 param0;
 ; PTX-NEXT:    .param .b32 retval0;
-; PTX-NEXT:    st.param.b64 [param0], %rd5;
+; PTX-NEXT:    st.param.b64 [param0], %rd4;
 ; PTX-NEXT:    prototype_2 : .callprototype (.param .b32 _) _ (.param .b64 _);
-; PTX-NEXT:    mov.b64 %rd1, escape;
-; PTX-NEXT:    call (retval0), %rd1, (param0), prototype_2;
+; PTX-NEXT:    mov.b64 %rd5, escape;
+; PTX-NEXT:    call (retval0), %rd5, (param0), prototype_2;
 ; PTX-NEXT:    } // callseq 2
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel void @grid_const_partial_escape(
@@ -295,21 +295,21 @@ define ptx_kernel i32 @grid_const_partial_escapemem(ptr byval(%struct.s) %input,
 ; PTX-NEXT:    .reg .b64 %rd<6>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0:
-; PTX-NEXT:    mov.b64 %rd2, grid_const_partial_escapemem_param_0;
-; PTX-NEXT:    ld.param.b64 %rd3, [grid_const_partial_escapemem_param_1];
-; PTX-NEXT:    cvta.to.global.u64 %rd4, %rd3;
-; PTX-NEXT:    cvta.param.u64 %rd5, %rd2;
+; PTX-NEXT:    mov.b64 %rd1, grid_const_partial_escapemem_param_0;
+; PTX-NEXT:    ld.param.b64 %rd2, [grid_const_partial_escapemem_param_1];
+; PTX-NEXT:    cvta.to.global.u64 %rd3, %rd2;
+; PTX-NEXT:    cvta.param.u64 %rd4, %rd1;
 ; PTX-NEXT:    ld.param.b32 %r1, [grid_const_partial_escapemem_param_0];
 ; PTX-NEXT:    ld.param.b32 %r2, [grid_const_partial_escapemem_param_0+4];
-; PTX-NEXT:    st.global.b64 [%rd4], %rd5;
+; PTX-NEXT:    st.global.b64 [%rd3], %rd4;
 ; PTX-NEXT:    add.s32 %r3, %r1, %r2;
 ; PTX-NEXT:    { // callseq 3, 0
 ; PTX-NEXT:    .param .b64 param0;
 ; PTX-NEXT:    .param .b32 retval0;
-; PTX-NEXT:    st.param.b64 [param0], %rd5;
+; PTX-NEXT:    st.param.b64 [param0], %rd4;
 ; PTX-NEXT:    prototype_3 : .callprototype (.param .b32 _) _ (.param .b64 _);
-; PTX-NEXT:    mov.b64 %rd1, escape;
-; PTX-NEXT:    call (retval0), %rd1, (param0), prototype_3;
+; PTX-NEXT:    mov.b64 %rd5, escape;
+; PTX-NEXT:    call (retval0), %rd5, (param0), prototype_3;
 ; PTX-NEXT:    } // callseq 3
 ; PTX-NEXT:    st.param.b32 [func_retval0], %r3;
 ; PTX-NEXT:    ret;
diff --git a/llvm/test/CodeGen/NVPTX/param-vectorize-device.ll b/llvm/test/CodeGen/NVPTX/param-vectorize-device.ll
index a592b82614f43..51f6b00601069 100644
--- a/llvm/test/CodeGen/NVPTX/param-vectorize-device.ll
+++ b/llvm/test/CodeGen/NVPTX/param-vectorize-device.ll
@@ -150,8 +150,8 @@ define dso_local void @caller_St4x3(ptr nocapture noundef readonly byval(%struct
   ; CHECK:       )
   ; CHECK:       .param .align 16 .b8 param0[12];
   ; CHECK:       .param .align 16 .b8 retval0[12];
-  ; CHECK:       st.param.v2.b32 [param0], {{{%r[0-9]+}}, {{%r[0-9]+}}};
-  ; CHECK:       st.param.b32    [param0+8], {{%r[0-9]+}};
+  ; CHECK-DAG:   st.param.v2.b32 [param0], {{{%r[0-9]+}}, {{%r[0-9]+}}};
+  ; CHECK-DAG:   st.param.b32    [param0+8], {{%r[0-9]+}};
   ; CHECK:       call.uni (retval0), callee_St4x3, (param0);
   ; CHECK:       ld.param.v2.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0];
   ; CHECK:       ld.param.b32    {{%r[0-9]+}},  [retval0+8];
@@ -240,8 +240,8 @@ define dso_local void @caller_St4x5(ptr nocapture noundef readonly byval(%struct
   ; CHECK:       )
   ; CHECK:       .param .align 16 .b8 param0[20];
   ; CHECK:       .param .align 16 .b8 retval0[20];
-  ; CHECK:       st.param.v4.b32 [param0],  {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
-  ; CHECK:       st.param.b32    [param0+16], {{%r[0-9]+}};
+  ; CHECK-DAG:   st.param.v4.b32 [param0],  {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
+  ; CHECK-DAG:   st.param.b32    [param0+16], {{%r[0-9]+}};
   ; CHECK:       call.uni (retval0), callee_St4x5, (param0);
   ; CHECK:       ld.param.v4.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0];
   ; CHECK:       ld.param.b32    {{%r[0-9]+}},  [retval0+16];
@@ -296,8 +296,8 @@ define dso_local void @caller_St4x6(ptr nocapture noundef readonly byval(%struct
   ; CHECK:       )
   ; CHECK:       .param .align 16 .b8 param0[24];
   ; CHECK:       .param .align 16 .b8 retval0[24];
-  ; CHECK:       st.param.v4.b32 [param0],  {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
-  ; CHECK:       st.param.v2.b32 [param0+16], {{{%r[0-9]+}}, {{%r[0-9]+}}};
+  ; CHECK-DAG:   st.param.v4.b32 [param0],  {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
+  ; CHECK-DAG:   st.param.v2.b32 [param0+16], {{{%r[0-9]+}}, {{%r[0-9]+}}};
   ; CHECK:       call.uni (retval0), callee_St4x6, (param0);
   ; CHECK:       ld.param.v4.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0];
   ; CHECK:       ld.param.v2.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0+16];
@@ -358,9 +358,9 @@ define dso_local void @caller_St4x7(ptr nocapture noundef readonly byval(%struct
   ; CHECK:       )
   ; CHECK:       .param .align 16 .b8 param0[28];
   ; CHECK:       .param .align 16 .b8 retval0[28];
-  ; CHECK:       st.param.v4.b32 [param0],  {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
-  ; CHECK:       st.param.v2.b32 [param0+16], {{{%r[0-9]+}}, {{%r[0-9]+}}};
-  ; CHECK:       st.param.b32    [param0+24], {{%r[0-9]+}};
+  ; CHECK-DAG:   st.param.v4.b32 [param0],  {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}};
+  ; CHECK-DAG:   st.param.v2.b32 [param0+16], {{{%r[0-9]+}}, {{%r[0-9]+}}};
+  ; CHECK-DAG:   st.param.b32    [param0+24], {{%r[0-9]+}};
   ; CHECK:       call.uni (retval0), callee_St4x7, (param0);
   ; CHECK:       ld.param.v4.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0];
   ; CHECK:       ld.param.v2.b32 {{{%r[0-9]+}}, {{%r[0-9]+}}}, [retval0+16];
@@ -566,8 +566,8 @@ define dso_local void @caller_St8x3(ptr nocapture noundef readonly byval(%struct
   ; CHECK:       )
   ; CHECK:       .param .align 16 .b8 param0[24];
   ; CHECK:       .param .align 16 .b8 retval0[24];
-  ; CHECK:       st.param.v2.b64 [param0],  {{{%rd[0-9]+}}, {{%rd[0-9]+}}};
-  ; CHECK:       st.param.b64    [param0+16], {{%rd[0-9]+}};
+  ; CHECK-DAG:   st.param.v2.b64 [param0],  {{{%rd[0-9]+}}, {{%rd[0-9]+}}};
+  ; CHECK-DAG:   st.param.b64    [param0+16], {{%rd[0-9]+}};
   ; CHECK:       call.uni (retval0), callee_St8x3, (param0);
   ; CHECK:       ld.param.v2.b64 {{{%rd[0-9]+}}, {{%rd[0-9]+}}}, [retval0];
   ; CHECK:       ld.param.b64    {{%rd[0-9]+}}, [retval0+16];

>From ca2239237cb847eef14e980d8af5ed4d327b8289 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Thu, 31 Jul 2025 05:05:21 +0000
Subject: [PATCH 3/3] address comments

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 22 ++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 3e48e61f8c6c8..124914d9e4e9d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -382,12 +382,16 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
   }
 }
 
+// We return an EVT that can hold N values of type VT.
+// If the VT is a vector, the resulting EVT is a flat vector with the same
+// element type as VT's element type.
 static EVT getVectorizedVT(EVT VT, unsigned N, LLVMContext &C) {
   if (N == 1)
     return VT;
 
-  const unsigned PackingAmt = VT.isVector() ? VT.getVectorNumElements() : 1;
-  return EVT::getVectorVT(C, VT.getScalarType(), N * PackingAmt);
+  return VT.isVector() ? EVT::getVectorVT(C, VT.getScalarType(),
+                                          VT.getVectorNumElements() * N)
+                       : EVT::getVectorVT(C, VT, N);
 }
 
 static SDValue getExtractVectorizedValue(SDValue V, unsigned I, EVT VT,
@@ -407,9 +411,8 @@ static SDValue getExtractVectorizedValue(SDValue V, unsigned I, EVT VT,
 }
 
 template <typename T>
-static inline SDValue getBuildVectorizedValue(T GetElement, unsigned N,
-                                              const SDLoc &dl,
-                                              SelectionDAG &DAG) {
+static inline SDValue getBuildVectorizedValue(unsigned N, const SDLoc &dl,
+                                              SelectionDAG &DAG, T GetElement) {
   if (N == 1)
     return GetElement(0);
 
@@ -1650,9 +1653,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                                             ? MaybeAlign(std::nullopt)
                                             : commonAlignment(ArgAlign, Offset);
 
-        SDValue Val = getBuildVectorizedValue(
-            [&](unsigned K) { return GetStoredValue(J + K); }, NumElts, dl,
-            DAG);
+        SDValue Val =
+            getBuildVectorizedValue(NumElts, dl, DAG, [&](unsigned K) {
+              return GetStoredValue(J + K);
+            });
 
         SDValue StoreParam =
             DAG.getStore(ArgDeclare, dl, Val, Ptr,
@@ -3416,7 +3420,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
                                         : commonAlignment(RetAlign, Offsets[I]);
 
     SDValue Val = getBuildVectorizedValue(
-        [&](unsigned K) { return GetRetVal(I + K); }, NumElts, dl, DAG);
+        NumElts, dl, DAG, [&](unsigned K) { return GetRetVal(I + K); });
 
     SDValue Ptr =
         DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));



More information about the llvm-commits mailing list