[llvm] [NVPTX] Cleanup NVPTXLowerArgs, simplifying logic and improving alignment propagation (PR #180286)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 12 21:08:36 PST 2026


https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/180286

>From ea8a4e35034b30130cc031362ce5373bdf3e5971 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Fri, 6 Feb 2026 21:28:17 +0000
Subject: [PATCH 1/3] [NVPTX] Cleanup NVPTXLowerArgs, simplifying logic and
 improving alignment propagation

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   |   2 +
 llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp      | 285 ++++++++----------
 llvm/test/CodeGen/NVPTX/bug21465.ll           |   2 +-
 .../CodeGen/NVPTX/lower-args-alignment.ll     |   6 +-
 .../CodeGen/NVPTX/lower-args-gridconstant.ll  |  42 +--
 llvm/test/CodeGen/NVPTX/lower-args.ll         |   4 +-
 llvm/test/CodeGen/NVPTX/lower-byval-args.ll   | 154 ++++++----
 7 files changed, 254 insertions(+), 241 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 755d270563786..8f1b70533a869 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -4115,6 +4115,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
 
       SDValue P;
       if (isKernelFunction(F)) {
+        assert(isParamGridConstant(Arg) && "ByVal argument must be lowered to "
+                                           "grid_constant by NVPTXLowerArgs");
         P = ArgSymbol;
         P.getNode()->setIROrder(Arg.getArgNo() + 1);
       } else {
diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index e2bbe57c0085c..c9d761345925d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -143,8 +143,10 @@
 #include "llvm/Analysis/PtrUseVisitor.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/IR/Attributes.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/IntrinsicsNVPTX.h"
@@ -207,8 +209,7 @@ INITIALIZE_PASS_END(NVPTXLowerArgsLegacyPass, "nvptx-lower-args",
 // pointer in parameter AS.
 // For "escapes" (to memory, a function call, or a ptrtoint), cast the OldUse to
 // generic using cvta.param.
-static void convertToParamAS(Use *OldUse, Value *Param, bool HasCvtaParam,
-                             bool IsGridConstant) {
+static void convertToParamAS(Use *OldUse, Value *Param, bool HasCvtaParam) {
   Instruction *I = dyn_cast<Instruction>(OldUse->getUser());
   assert(I && "OldUse must be in an instruction");
   struct IP {
@@ -219,8 +220,7 @@ static void convertToParamAS(Use *OldUse, Value *Param, bool HasCvtaParam,
   SmallVector<IP> ItemsToConvert = {{OldUse, I, Param}};
   SmallVector<Instruction *> InstructionsToDelete;
 
-  auto CloneInstInParamAS = [HasCvtaParam,
-                             IsGridConstant](const IP &I) -> Value * {
+  auto CloneInstInParamAS = [HasCvtaParam](const IP &I) -> Value * {
     if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
       LI->setOperand(0, I.NewParam);
       return LI;
@@ -285,28 +285,6 @@ static void convertToParamAS(Use *OldUse, Value *Param, bool HasCvtaParam,
         if (SI->getFalseValue() == I.OldUse->get())
           SI->setFalseValue(ParamInGenericAS);
       }
-
-      // Escapes or writes can only use generic param pointers if
-      // __grid_constant__ is in effect.
-      if (IsGridConstant) {
-        if (auto *CI = dyn_cast<CallInst>(I.OldInstruction)) {
-          I.OldUse->set(ParamInGenericAS);
-          return CI;
-        }
-        if (auto *SI = dyn_cast<StoreInst>(I.OldInstruction)) {
-          // byval address is being stored, cast it to generic
-          if (SI->getValueOperand() == I.OldUse->get())
-            SI->setOperand(0, ParamInGenericAS);
-          return SI;
-        }
-        if (auto *PI = dyn_cast<PtrToIntInst>(I.OldInstruction)) {
-          if (PI->getPointerOperand() == I.OldUse->get())
-            PI->setOperand(0, ParamInGenericAS);
-          return PI;
-        }
-        // TODO: iIf we allow stores, we should allow memcpy/memset to
-        // parameter, too.
-      }
     }
 
     llvm_unreachable("Unsupported instruction");
@@ -338,34 +316,36 @@ static void convertToParamAS(Use *OldUse, Value *Param, bool HasCvtaParam,
     I->eraseFromParent();
 }
 
-// Adjust alignment of arguments passed byval in .param address space. We can
-// increase alignment of such arguments in a way that ensures that we can
-// effectively vectorize their loads. We should also traverse all loads from
-// byval pointer and adjust their alignment, if those were using known offset.
-// Such alignment changes must be conformed with parameter store and load in
-// NVPTXTargetLowering::LowerCall.
-static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
-                                    const NVPTXTargetLowering *TLI) {
-  Function *Func = Arg->getParent();
-  Type *StructType = Arg->getParamByValType();
-  const DataLayout &DL = Func->getDataLayout();
+static Align setByValParamAlign(Argument *Arg, const NVPTXTargetLowering *TLI) {
+  Function *F = Arg->getParent();
+  Type *ByValType = Arg->getParamByValType();
+  const DataLayout &DL = F->getDataLayout();
 
-  const Align NewArgAlign =
-      TLI->getFunctionParamOptimizedAlign(Func, StructType, DL);
-  const Align CurArgAlign = Arg->getParamAlign().valueOrOne();
+  const Align OptimizedAlign =
+      TLI->getFunctionParamOptimizedAlign(F, ByValType, DL);
+  const Align CurrentAlign = Arg->getParamAlign().valueOrOne();
 
-  if (CurArgAlign >= NewArgAlign)
-    return;
+  if (CurrentAlign >= OptimizedAlign)
+    return CurrentAlign;
 
-  LLVM_DEBUG(dbgs() << "Try to use alignment " << NewArgAlign.value()
-                    << " instead of " << CurArgAlign.value() << " for " << *Arg
+  LLVM_DEBUG(dbgs() << "Try to use alignment " << OptimizedAlign.value()
+                    << " instead of " << CurrentAlign.value() << " for " << *Arg
                     << '\n');
 
-  auto NewAlignAttr =
-      Attribute::getWithAlignment(Func->getContext(), NewArgAlign);
   Arg->removeAttr(Attribute::Alignment);
-  Arg->addAttr(NewAlignAttr);
+  Arg->addAttr(Attribute::getWithAlignment(F->getContext(), OptimizedAlign));
 
+  return OptimizedAlign;
+}
+
+// Adjust alignment of arguments passed byval in .param address space. We can
+// increase alignment of such arguments in a way that ensures that we can
+// effectively vectorize their loads. We should also traverse all loads from
+// byval pointer and adjust their alignment, if those were using known offset.
+// Such alignment changes must conform with parameter store and load in
+// NVPTXTargetLowering::LowerCall.
+static void propagateAlignmentToLoads(Value *Val, Align NewAlign,
+                                      const DataLayout &DL) {
   struct Load {
     LoadInst *Inst;
     uint64_t Offset;
@@ -378,7 +358,7 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
 
   SmallVector<Load> Loads;
   std::queue<LoadContext> Worklist;
-  Worklist.push({ArgInParamAS, 0});
+  Worklist.push({Val, 0});
 
   while (!Worklist.empty()) {
     LoadContext Ctx = Worklist.front();
@@ -406,7 +386,7 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
   }
 
   for (Load &CurLoad : Loads) {
-    Align NewLoadAlign(std::gcd(NewArgAlign.value(), CurLoad.Offset));
+    Align NewLoadAlign = commonAlignment(NewAlign, CurLoad.Offset);
     Align CurLoadAlign = CurLoad.Inst->getAlign();
     CurLoad.Inst->setAlignment(std::max(NewLoadAlign, CurLoadAlign));
   }
@@ -425,19 +405,19 @@ static CallInst *createNVVMInternalAddrspaceWrap(IRBuilder<> &IRB,
     ArgInParam->addRetAttr(
         Attribute::getWithAlignment(ArgInParam->getContext(), *ParamAlign));
 
+  Arg.addAttr(Attribute::get(Arg.getContext(), "nvvm.grid_constant"));
+  Arg.addAttr(Attribute::ReadOnly);
+
   return ArgInParam;
 }
 
 namespace {
 struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
   using Base = PtrUseVisitor<ArgUseChecker>;
-
-  bool IsGridConstant;
   // Set of phi/select instructions using the Arg
   SmallPtrSet<Instruction *, 4> Conditionals;
 
-  ArgUseChecker(const DataLayout &DL, bool IsGridConstant)
-      : PtrUseVisitor(DL), IsGridConstant(IsGridConstant) {}
+  ArgUseChecker(const DataLayout &DL) : PtrUseVisitor(DL) {}
 
   PtrInfo visitArgPtr(Argument &A) {
     assert(A.getType()->isPointerTy());
@@ -445,7 +425,6 @@ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
     IsOffsetKnown = false;
     Offset = APInt(IntIdxTy->getBitWidth(), 0);
     PI.reset();
-    Conditionals.clear();
 
     LLVM_DEBUG(dbgs() << "Checking Argument " << A << "\n");
     // Enqueue the uses of this pointer.
@@ -458,8 +437,6 @@ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
       UseToVisit ToVisit = Worklist.pop_back_val();
       U = ToVisit.UseAndIsOffsetKnown.getPointer();
       Instruction *I = cast<Instruction>(U->getUser());
-      if (isa<PHINode>(I) || isa<SelectInst>(I))
-        Conditionals.insert(I);
       LLVM_DEBUG(dbgs() << "Processing " << *I << "\n");
       Base::visit(I);
     }
@@ -478,10 +455,8 @@ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
     // Storing the pointer escapes it.
     if (U->get() == SI.getValueOperand())
       return PI.setEscapedAndAborted(&SI);
-    // Writes to the pointer are UB w/ __grid_constant__, but do not force a
-    // copy.
-    if (!IsGridConstant)
-      return PI.setAborted(&SI);
+
+    PI.setAborted(&SI);
   }
 
   void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) {
@@ -491,120 +466,119 @@ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
     Base::visitAddrSpaceCastInst(ASC);
   }
 
-  void visitPtrToIntInst(PtrToIntInst &I) {
-    if (IsGridConstant)
-      return;
-    Base::visitPtrToIntInst(I);
-  }
+  void visitPtrToIntInst(PtrToIntInst &I) { Base::visitPtrToIntInst(I); }
+
   void visitPHINodeOrSelectInst(Instruction &I) {
     assert(isa<PHINode>(I) || isa<SelectInst>(I));
+    enqueueUsers(I);
+    Conditionals.insert(&I);
   }
   // PHI and select just pass through the pointers.
-  void visitPHINode(PHINode &PN) { enqueueUsers(PN); }
-  void visitSelectInst(SelectInst &SI) { enqueueUsers(SI); }
+  void visitPHINode(PHINode &PN) { visitPHINodeOrSelectInst(PN); }
+  void visitSelectInst(SelectInst &SI) { visitPHINodeOrSelectInst(SI); }
 
+  // memcpy/memmove are OK when the pointer is source. We can convert them to
+  // AS-specific memcpy.
   void visitMemTransferInst(MemTransferInst &II) {
-    if (*U == II.getRawDest() && !IsGridConstant)
+    if (*U == II.getRawDest())
       PI.setAborted(&II);
-    // memcpy/memmove are OK when the pointer is source. We can convert them to
-    // AS-specific memcpy.
   }
 
-  void visitMemSetInst(MemSetInst &II) {
-    if (!IsGridConstant)
-      PI.setAborted(&II);
-  }
+  void visitMemSetInst(MemSetInst &II) { PI.setAborted(&II); }
 }; // struct ArgUseChecker
 
 void copyByValParam(Function &F, Argument &Arg) {
   LLVM_DEBUG(dbgs() << "Creating a local copy of " << Arg << "\n");
-  // Otherwise we have to create a temporary copy.
-  BasicBlock::iterator FirstInst = F.getEntryBlock().begin();
-  Type *StructType = Arg.getParamByValType();
+  Type *ByValType = Arg.getParamByValType();
   const DataLayout &DL = F.getDataLayout();
-  IRBuilder<> IRB(&*FirstInst);
-  AllocaInst *AllocA = IRB.CreateAlloca(StructType, nullptr, Arg.getName());
+  IRBuilder<> IRB(&F.getEntryBlock().front());
+  AllocaInst *AllocA = IRB.CreateAlloca(ByValType, nullptr, Arg.getName());
   // Set the alignment to alignment of the byval parameter. This is because,
   // later load/stores assume that alignment, and we are going to replace
   // the use of the byval parameter with this alloca instruction.
   AllocA->setAlignment(
-      Arg.getParamAlign().value_or(DL.getPrefTypeAlign(StructType)));
+      Arg.getParamAlign().value_or(DL.getPrefTypeAlign(ByValType)));
   Arg.replaceAllUsesWith(AllocA);
 
-  CallInst *ArgInParam = createNVVMInternalAddrspaceWrap(IRB, Arg);
+  Value *ArgInParamAS = createNVVMInternalAddrspaceWrap(IRB, Arg);
 
   // Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
   // addrspacecast preserves alignment.  Since params are constant, this load
   // is definitely not volatile.
   const auto ArgSize = *AllocA->getAllocationSize(DL);
-  IRB.CreateMemCpy(AllocA, AllocA->getAlign(), ArgInParam, AllocA->getAlign(),
+  IRB.CreateMemCpy(AllocA, AllocA->getAlign(), ArgInParamAS, AllocA->getAlign(),
                    ArgSize);
 }
 } // namespace
 
+static bool argIsProcessed(Argument *Arg) {
+  if (Arg->use_empty())
+    return true;
+
+  // If the argument is already wrapped, it was processed by this pass before.
+  if (Arg->hasOneUse())
+    if (const auto *II = dyn_cast<IntrinsicInst>(*Arg->user_begin()))
+      if (II->getIntrinsicID() == Intrinsic::nvvm_internal_addrspace_wrap)
+        return true;
+
+  return false;
+}
+
 static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
-  Function *Func = Arg->getParent();
-  assert(isKernelFunction(*Func));
-  const bool HasCvtaParam = TM.getSubtargetImpl(*Func)->hasCvtaParam();
-  const bool IsGridConstant = HasCvtaParam && isParamGridConstant(*Arg);
-  const DataLayout &DL = Func->getDataLayout();
-  BasicBlock::iterator FirstInst = Func->getEntryBlock().begin();
-  [[maybe_unused]] Type *StructType = Arg->getParamByValType();
-  assert(StructType && "Missing byval type");
-
-  ArgUseChecker AUC(DL, IsGridConstant);
+  Function *F = Arg->getParent();
+  assert(isKernelFunction(*F));
+  const NVPTXSubtarget *ST = TM.getSubtargetImpl(*F);
+  const bool HasCvtaParam = ST->hasCvtaParam();
+
+  const DataLayout &DL = F->getDataLayout();
+  IRBuilder<> IRB(&F->getEntryBlock().front());
+
+  if (argIsProcessed(Arg))
+    return;
+
+  const Align NewArgAlign = setByValParamAlign(Arg, ST->getTargetLowering());
+
+  // (1) First check the easy case: if we're able to trace through all the uses
+  // and we can convert them all to param AS, then we'll do this.
+  ArgUseChecker AUC(DL);
   ArgUseChecker::PtrInfo PI = AUC.visitArgPtr(*Arg);
-  bool ArgUseIsReadOnly = !(PI.isEscaped() || PI.isAborted());
-  // Easy case, accessing parameter directly is fine.
+  const bool ArgUseIsReadOnly = !(PI.isEscaped() || PI.isAborted());
   if (ArgUseIsReadOnly && AUC.Conditionals.empty()) {
     // Convert all loads and intermediate operations to use parameter AS and
     // skip creation of a local copy of the argument.
     SmallVector<Use *, 16> UsesToUpdate(llvm::make_pointer_range(Arg->uses()));
-
-    IRBuilder<> IRB(&*FirstInst);
-    CallInst *ArgInParamAS = createNVVMInternalAddrspaceWrap(IRB, *Arg);
-
+    Value *ArgInParamAS = createNVVMInternalAddrspaceWrap(IRB, *Arg);
     for (Use *U : UsesToUpdate)
-      convertToParamAS(U, ArgInParamAS, HasCvtaParam, IsGridConstant);
-    LLVM_DEBUG(dbgs() << "No need to copy or cast " << *Arg << "\n");
-
-    const auto *TLI =
-        cast<NVPTXTargetLowering>(TM.getSubtargetImpl()->getTargetLowering());
-
-    adjustByValArgAlignment(Arg, ArgInParamAS, TLI);
+      convertToParamAS(U, ArgInParamAS, HasCvtaParam);
 
+    propagateAlignmentToLoads(ArgInParamAS, NewArgAlign, DL);
     return;
   }
 
-  // We can't access byval arg directly and need a pointer. on sm_70+ we have
-  // ability to take a pointer to the argument without making a local copy.
-  // However, we're still not allowed to write to it. If the user specified
-  // `__grid_constant__` for the argument, we'll consider escaped pointer as
-  // read-only.
-  if (IsGridConstant || (HasCvtaParam && ArgUseIsReadOnly)) {
+  // (2) If the argument is grid constant or only read, we can use the pointer directly.
+  if (HasCvtaParam && (ArgUseIsReadOnly || isParamGridConstant(*Arg))) {
     LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << *Arg << "\n");
-    // Replace all argument pointer uses (which might include a device function
-    // call) with a cast to the generic address space using cvta.param
-    // instruction, which avoids a local copy.
-    IRBuilder<> IRB(&Func->getEntryBlock().front());
 
     // Cast argument to param address space. Because the backend will emit the
     // argument already in the param address space, we need to use the noop
     // intrinsic, this had the added benefit of preventing other optimizations
     // from folding away this pair of addrspacecasts.
-    auto *ParamSpaceArg = createNVVMInternalAddrspaceWrap(IRB, *Arg);
+    Instruction *ArgInParamAS = createNVVMInternalAddrspaceWrap(IRB, *Arg);
 
     // Cast param address to generic address space.
     Value *GenericArg = IRB.CreateAddrSpaceCast(
-        ParamSpaceArg, IRB.getPtrTy(ADDRESS_SPACE_GENERIC),
+        ArgInParamAS, IRB.getPtrTy(ADDRESS_SPACE_GENERIC),
         Arg->getName() + ".gen");
 
     Arg->replaceAllUsesWith(GenericArg);
 
     // Do not replace Arg in the cast to param space
-    ParamSpaceArg->setOperand(0, Arg);
-  } else
-    copyByValParam(*Func, *Arg);
+    ArgInParamAS->setOperand(0, Arg);
+    return;
+  }
+
+  // (3) Otherwise we have to create a copy of the argument in local memory.
+  copyByValParam(*F, *Arg);
 }
 
 static void markPointerAsAS(Value *Ptr, const unsigned AS) {
@@ -636,6 +610,15 @@ static void markPointerAsGlobal(Value *Ptr) {
   markPointerAsAS(Ptr, ADDRESS_SPACE_GLOBAL);
 }
 
+static void handleIntToPtr(Value &V) {
+  if (!all_of(V.users(), [](User *U) { return isa<IntToPtrInst>(U); }))
+    return;
+
+  SmallVector<User *, 16> UsersToUpdate(V.users());
+  for (User *U : UsersToUpdate)
+    markPointerAsGlobal(U);
+}
+
 // =============================================================================
 // Main function for this pass.
 // =============================================================================
@@ -644,44 +627,37 @@ static bool runOnKernelFunction(const NVPTXTargetMachine &TM, Function &F) {
   // integers, followed by intotoptr. We may want to mark those as global, too,
   // but only if the loaded integer is used exclusively for conversion to a
   // pointer with inttoptr.
-  auto HandleIntToPtr = [](Value &V) {
-    if (llvm::all_of(V.users(), [](User *U) { return isa<IntToPtrInst>(U); })) {
-      SmallVector<User *, 16> UsersToUpdate(V.users());
-      for (User *U : UsersToUpdate)
-        markPointerAsGlobal(U);
-    }
-  };
   if (TM.getDrvInterface() == NVPTX::CUDA) {
     // Mark pointers in byval structs as global.
-    for (auto &B : F) {
-      for (auto &I : B) {
-        if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
-          if (LI->getType()->isPointerTy() || LI->getType()->isIntegerTy()) {
-            Value *UO = getUnderlyingObject(LI->getPointerOperand());
-            if (Argument *Arg = dyn_cast<Argument>(UO)) {
-              if (Arg->hasByValAttr()) {
-                // LI is a load from a pointer within a byval kernel parameter.
-                if (LI->getType()->isPointerTy())
-                  markPointerAsGlobal(LI);
-                else
-                  HandleIntToPtr(*LI);
-              }
-            }
+    for (auto &I : instructions(F)) {
+      auto *LI = dyn_cast<LoadInst>(&I);
+      if (!LI)
+        continue;
+
+      if (LI->getType()->isPointerTy() || LI->getType()->isIntegerTy()) {
+        Value *UO = getUnderlyingObject(LI->getPointerOperand());
+        if (Argument *Arg = dyn_cast<Argument>(UO)) {
+          if (Arg->hasByValAttr()) {
+            // LI is a load from a pointer within a byval kernel parameter.
+            if (LI->getType()->isPointerTy())
+              markPointerAsGlobal(LI);
+            else
+              handleIntToPtr(*LI);
           }
         }
       }
     }
+
+    for (Argument &Arg : F.args())
+      if (Arg.getType()->isIntegerTy())
+        handleIntToPtr(Arg);
   }
 
   LLVM_DEBUG(dbgs() << "Lowering kernel args of " << F.getName() << "\n");
-  for (Argument &Arg : F.args()) {
-    if (Arg.getType()->isPointerTy() && Arg.hasByValAttr()) {
+  for (Argument &Arg : F.args())
+    if (Arg.hasByValAttr())
       handleByValParam(TM, &Arg);
-    } else if (Arg.getType()->isIntegerTy() &&
-               TM.getDrvInterface() == NVPTX::CUDA) {
-      HandleIntToPtr(Arg);
-    }
-  }
+
   return true;
 }
 
@@ -689,12 +665,14 @@ static bool runOnKernelFunction(const NVPTXTargetMachine &TM, Function &F) {
 static bool runOnDeviceFunction(const NVPTXTargetMachine &TM, Function &F) {
   LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n");
 
-  const auto *TLI =
-      cast<NVPTXTargetLowering>(TM.getSubtargetImpl()->getTargetLowering());
+  const NVPTXTargetLowering *TLI = TM.getSubtargetImpl()->getTargetLowering();
+  const DataLayout &DL = F.getDataLayout();
 
   for (Argument &Arg : F.args())
-    if (Arg.getType()->isPointerTy() && Arg.hasByValAttr())
-      adjustByValArgAlignment(&Arg, &Arg, TLI);
+    if (Arg.hasByValAttr()) {
+      const Align NewArgAlign = setByValParamAlign(&Arg, TLI);
+      propagateAlignmentToLoads(&Arg, NewArgAlign, DL);
+    }
 
   return true;
 }
@@ -718,8 +696,7 @@ static bool copyFunctionByValArgs(Function &F) {
   bool Changed = false;
   if (isKernelFunction(F)) {
     for (Argument &Arg : F.args())
-      if (Arg.getType()->isPointerTy() && Arg.hasByValAttr() &&
-          !isParamGridConstant(Arg)) {
+      if (Arg.hasByValAttr() && !isParamGridConstant(Arg)) {
         copyByValParam(F, Arg);
         Changed = true;
       }
diff --git a/llvm/test/CodeGen/NVPTX/bug21465.ll b/llvm/test/CodeGen/NVPTX/bug21465.ll
index 79b0dbcf6494c..8730b64d42f58 100644
--- a/llvm/test/CodeGen/NVPTX/bug21465.ll
+++ b/llvm/test/CodeGen/NVPTX/bug21465.ll
@@ -12,7 +12,7 @@ define ptx_kernel void @_Z11TakesStruct1SPi(ptr byval(%struct.S) nocapture reado
 entry:
 ; CHECK-LABEL: @_Z11TakesStruct1SPi
 ; PTX-LABEL: .visible .entry _Z11TakesStruct1SPi(
-; CHECK: call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr %input)
+; CHECK: call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr %input)
   %b = getelementptr inbounds %struct.S, ptr %input, i64 0, i32 1
   %0 = load i32, ptr %b, align 4
 ; PTX-NOT: ld.param.b32 {{%r[0-9]+}}, [{{%rd[0-9]+}}]
diff --git a/llvm/test/CodeGen/NVPTX/lower-args-alignment.ll b/llvm/test/CodeGen/NVPTX/lower-args-alignment.ll
index 2051f6305cc03..fe4d27b02c7fb 100644
--- a/llvm/test/CodeGen/NVPTX/lower-args-alignment.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args-alignment.ll
@@ -11,7 +11,7 @@ target triple = "nvptx64-nvidia-cuda"
 %struct.S1 = type { i32, i32, i32, i32 }
 define ptx_kernel i32 @test_align8(ptr noundef readonly byval(%struct.S1) align 8 captures(none) %params) {
 ; CHECK-LABEL: define ptx_kernel i32 @test_align8(
-; CHECK-SAME: ptr noundef readonly byval([[STRUCT_S1:%.*]]) align 8 captures(none) [[PARAMS:%.*]]) {
+; CHECK-SAME: ptr noundef readonly byval([[STRUCT_S1:%.*]]) align 8 captures(none) "nvvm.grid_constant" [[PARAMS:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
 ; CHECK-NEXT:    [[TMP0:%.*]] = call align 8 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[PARAMS]])
 ; CHECK-NEXT:    [[LOAD:%.*]] = load i32, ptr addrspace(101) [[TMP0]], align 8
@@ -24,9 +24,9 @@ entry:
 
 define ptx_kernel i32 @test_align1(ptr noundef readonly byval(%struct.S1) align 1 captures(none) %params) {
 ; CHECK-LABEL: define ptx_kernel i32 @test_align1(
-; CHECK-SAME: ptr noundef readonly byval([[STRUCT_S1:%.*]]) align 4 captures(none) [[PARAMS:%.*]]) {
+; CHECK-SAME: ptr noundef readonly byval([[STRUCT_S1:%.*]]) align 4 captures(none) "nvvm.grid_constant" [[PARAMS:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    [[TMP0:%.*]] = call align 1 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[PARAMS]])
+; CHECK-NEXT:    [[TMP0:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[PARAMS]])
 ; CHECK-NEXT:    [[LOAD:%.*]] = load i32, ptr addrspace(101) [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[LOAD]]
 ;
diff --git a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
index 01ab47145940c..ab76de3326f63 100644
--- a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
@@ -1,6 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
 ; RUN: opt < %s -S -nvptx-lower-args --mtriple nvptx64-nvidia-cuda -mcpu=sm_70 -mattr=+ptx77 | FileCheck %s --check-prefixes OPT
-; RUN: llc < %s --mtriple nvptx64-nvidia-cuda -mcpu=sm_70 -mattr=+ptx77 | FileCheck %s --check-prefixes PTX
+; RUN: llc < %s --mtriple nvptx64-nvidia-cuda -mcpu=sm_70 -mattr=+ptx77 -O1 | FileCheck %s --check-prefixes PTX
 
 %struct.uint4 = type { i32, i32, i32, i32 }
 
@@ -71,7 +71,7 @@ define ptx_kernel void @grid_const_int(ptr byval(i32) align 4 "nvvm.grid_constan
 ; PTX-NEXT:    st.global.b32 [%rd2], %r3;
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel void @grid_const_int(
-; OPT-SAME: ptr byval(i32) align 4 "nvvm.grid_constant" [[INPUT1:%.*]], i32 [[INPUT2:%.*]], ptr [[OUT:%.*]], i32 [[N:%.*]]) #[[ATTR0]] {
+; OPT-SAME: ptr readonly byval(i32) align 4 "nvvm.grid_constant" [[INPUT1:%.*]], i32 [[INPUT2:%.*]], ptr [[OUT:%.*]], i32 [[N:%.*]]) #[[ATTR0]] {
 ; OPT-NEXT:    [[INPUT11:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT1]])
 ; OPT-NEXT:    [[TMP:%.*]] = load i32, ptr addrspace(101) [[INPUT11]], align 4
 ; OPT-NEXT:    [[ADD:%.*]] = add i32 [[TMP]], [[INPUT2]]
@@ -100,7 +100,7 @@ define ptx_kernel void @grid_const_struct(ptr byval(%struct.s) align 4 "nvvm.gri
 ; PTX-NEXT:    st.global.b32 [%rd2], %r3;
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel void @grid_const_struct(
-; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT:%.*]], ptr [[OUT:%.*]]) #[[ATTR0]] {
+; OPT-SAME: ptr readonly byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT:%.*]], ptr [[OUT:%.*]]) #[[ATTR0]] {
 ; OPT-NEXT:    [[INPUT1:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT]])
 ; OPT-NEXT:    [[GEP13:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr addrspace(101) [[INPUT1]], i32 0, i32 0
 ; OPT-NEXT:    [[GEP22:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr addrspace(101) [[INPUT1]], i32 0, i32 1
@@ -136,7 +136,7 @@ define ptx_kernel void @grid_const_escape(ptr byval(%struct.s) align 4 "nvvm.gri
 ; PTX-NEXT:    } // callseq 0
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel void @grid_const_escape(
-; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT:%.*]]) #[[ATTR0]] {
+; OPT-SAME: ptr readonly byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT:%.*]]) #[[ATTR0]] {
 ; OPT-NEXT:    [[TMP1:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT]])
 ; OPT-NEXT:    [[INPUT_PARAM_GEN:%.*]] = addrspacecast ptr addrspace(101) [[TMP1]] to ptr
 ; OPT-NEXT:    [[CALL:%.*]] = call i32 @escape(ptr [[INPUT_PARAM_GEN]])
@@ -179,7 +179,7 @@ define ptx_kernel void @multiple_grid_const_escape(ptr byval(%struct.s) align 4
 ; PTX-NEXT:    } // callseq 1
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel void @multiple_grid_const_escape(
-; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT:%.*]], i32 [[A:%.*]], ptr byval(i32) align 4 "nvvm.grid_constant" [[B:%.*]]) #[[ATTR0]] {
+; OPT-SAME: ptr readonly byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT:%.*]], i32 [[A:%.*]], ptr readonly byval(i32) align 4 "nvvm.grid_constant" [[B:%.*]]) #[[ATTR0]] {
 ; OPT-NEXT:    [[TMP1:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[B]])
 ; OPT-NEXT:    [[B_PARAM_GEN:%.*]] = addrspacecast ptr addrspace(101) [[TMP1]] to ptr
 ; OPT-NEXT:    [[TMP2:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT]])
@@ -207,7 +207,7 @@ define ptx_kernel void @grid_const_memory_escape(ptr byval(%struct.s) align 4 "n
 ; PTX-NEXT:    st.global.b64 [%rd3], %rd4;
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel void @grid_const_memory_escape(
-; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT:%.*]], ptr [[ADDR:%.*]]) #[[ATTR0]] {
+; OPT-SAME: ptr readonly byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT:%.*]], ptr [[ADDR:%.*]]) #[[ATTR0]] {
 ; OPT-NEXT:    [[TMP1:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT]])
 ; OPT-NEXT:    [[INPUT1:%.*]] = addrspacecast ptr addrspace(101) [[TMP1]] to ptr
 ; OPT-NEXT:    store ptr [[INPUT1]], ptr [[ADDR]], align 8
@@ -234,7 +234,7 @@ define ptx_kernel void @grid_const_inlineasm_escape(ptr byval(%struct.s) align 4
 ; PTX-NEXT:    ret;
 ; PTX-NOT      .local
 ; OPT-LABEL: define ptx_kernel void @grid_const_inlineasm_escape(
-; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT:%.*]], ptr [[RESULT:%.*]]) #[[ATTR0]] {
+; OPT-SAME: ptr readonly byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT:%.*]], ptr [[RESULT:%.*]]) #[[ATTR0]] {
 ; OPT-NEXT:    [[TMP1:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT]])
 ; OPT-NEXT:    [[INPUT1:%.*]] = addrspacecast ptr addrspace(101) [[TMP1]] to ptr
 ; OPT-NEXT:    [[TMPPTR1:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[INPUT1]], i32 0, i32 0
@@ -273,8 +273,8 @@ define ptx_kernel void @grid_const_partial_escape(ptr byval(i32) "nvvm.grid_cons
 ; PTX-NEXT:    } // callseq 2
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel void @grid_const_partial_escape(
-; OPT-SAME: ptr byval(i32) "nvvm.grid_constant" [[INPUT:%.*]], ptr [[OUTPUT:%.*]]) #[[ATTR0]] {
-; OPT-NEXT:    [[TMP1:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT]])
+; OPT-SAME: ptr readonly byval(i32) align 4 "nvvm.grid_constant" [[INPUT:%.*]], ptr [[OUTPUT:%.*]]) #[[ATTR0]] {
+; OPT-NEXT:    [[TMP1:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT]])
 ; OPT-NEXT:    [[INPUT1_GEN:%.*]] = addrspacecast ptr addrspace(101) [[TMP1]] to ptr
 ; OPT-NEXT:    [[VAL1:%.*]] = load i32, ptr [[INPUT1_GEN]], align 4
 ; OPT-NEXT:    [[TWICE:%.*]] = add i32 [[VAL1]], [[VAL1]]
@@ -314,8 +314,8 @@ define ptx_kernel i32 @grid_const_partial_escapemem(ptr byval(%struct.s) "nvvm.g
 ; PTX-NEXT:    st.param.b32 [func_retval0], %r3;
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel i32 @grid_const_partial_escapemem(
-; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) "nvvm.grid_constant" [[INPUT:%.*]], ptr [[OUTPUT:%.*]]) #[[ATTR0]] {
-; OPT-NEXT:    [[TMP1:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT]])
+; OPT-SAME: ptr readonly byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT:%.*]], ptr [[OUTPUT:%.*]]) #[[ATTR0]] {
+; OPT-NEXT:    [[TMP1:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT]])
 ; OPT-NEXT:    [[INPUT1:%.*]] = addrspacecast ptr addrspace(101) [[TMP1]] to ptr
 ; OPT-NEXT:    [[PTR1:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[INPUT1]], i32 0, i32 0
 ; OPT-NEXT:    [[VAL1:%.*]] = load i32, ptr [[PTR1]], align 4
@@ -356,7 +356,7 @@ define ptx_kernel void @grid_const_phi(ptr byval(%struct.s) align 4 "nvvm.grid_c
 ; PTX-NEXT:    st.global.b32 [%rd1], %r2;
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel void @grid_const_phi(
-; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT1:%.*]], ptr [[INOUT:%.*]]) #[[ATTR0]] {
+; OPT-SAME: ptr readonly byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT1:%.*]], ptr [[INOUT:%.*]]) #[[ATTR0]] {
 ; OPT-NEXT:    [[TMP1:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT1]])
 ; OPT-NEXT:    [[INPUT1_PARAM_GEN:%.*]] = addrspacecast ptr addrspace(101) [[TMP1]] to ptr
 ; OPT-NEXT:    [[VAL:%.*]] = load i32, ptr [[INOUT]], align 4
@@ -413,8 +413,8 @@ define ptx_kernel void @grid_const_phi_ngc(ptr byval(%struct.s) align 4 "nvvm.gr
 ; PTX-NEXT:    st.global.b32 [%rd1], %r2;
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel void @grid_const_phi_ngc(
-; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT1:%.*]], ptr byval([[STRUCT_S]]) [[INPUT2:%.*]], ptr [[INOUT:%.*]]) #[[ATTR0]] {
-; OPT-NEXT:    [[TMP1:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT2]])
+; OPT-SAME: ptr readonly byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT1:%.*]], ptr readonly byval([[STRUCT_S]]) align 4 "nvvm.grid_constant" [[INPUT2:%.*]], ptr [[INOUT:%.*]]) #[[ATTR0]] {
+; OPT-NEXT:    [[TMP1:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT2]])
 ; OPT-NEXT:    [[INPUT2_PARAM_GEN:%.*]] = addrspacecast ptr addrspace(101) [[TMP1]] to ptr
 ; OPT-NEXT:    [[TMP2:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT1]])
 ; OPT-NEXT:    [[INPUT1_PARAM_GEN:%.*]] = addrspacecast ptr addrspace(101) [[TMP2]] to ptr
@@ -468,8 +468,8 @@ define ptx_kernel void @grid_const_select(ptr byval(i32) align 4 "nvvm.grid_cons
 ; PTX-NEXT:    st.global.b32 [%rd3], %r2;
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel void @grid_const_select(
-; OPT-SAME: ptr byval(i32) align 4 "nvvm.grid_constant" [[INPUT1:%.*]], ptr byval(i32) [[INPUT2:%.*]], ptr [[INOUT:%.*]]) #[[ATTR0]] {
-; OPT-NEXT:    [[TMP1:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT2]])
+; OPT-SAME: ptr readonly byval(i32) align 4 "nvvm.grid_constant" [[INPUT1:%.*]], ptr readonly byval(i32) align 4 "nvvm.grid_constant" [[INPUT2:%.*]], ptr [[INOUT:%.*]]) #[[ATTR0]] {
+; OPT-NEXT:    [[TMP1:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT2]])
 ; OPT-NEXT:    [[INPUT2_PARAM_GEN:%.*]] = addrspacecast ptr addrspace(101) [[TMP1]] to ptr
 ; OPT-NEXT:    [[TMP2:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT1]])
 ; OPT-NEXT:    [[INPUT1_PARAM_GEN:%.*]] = addrspacecast ptr addrspace(101) [[TMP2]] to ptr
@@ -495,17 +495,17 @@ define ptx_kernel i32 @grid_const_ptrtoint(ptr byval(i32) "nvvm.grid_constant" %
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0:
 ; PTX-NEXT:    mov.b64 %rd1, grid_const_ptrtoint_param_0;
-; PTX-NEXT:    ld.param.b32 %r1, [grid_const_ptrtoint_param_0];
 ; PTX-NEXT:    cvta.param.u64 %rd2, %rd1;
+; PTX-NEXT:    ld.param.b32 %r1, [grid_const_ptrtoint_param_0];
 ; PTX-NEXT:    cvt.u32.u64 %r2, %rd2;
 ; PTX-NEXT:    add.s32 %r3, %r1, %r2;
 ; PTX-NEXT:    st.param.b32 [func_retval0], %r3;
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel i32 @grid_const_ptrtoint(
-; OPT-SAME: ptr byval(i32) align 4 "nvvm.grid_constant" [[INPUT:%.*]]) #[[ATTR0]] {
-; OPT-NEXT:    [[INPUT2:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT]])
-; OPT-NEXT:    [[INPUT3:%.*]] = load i32, ptr addrspace(101) [[INPUT2]], align 4
+; OPT-SAME: ptr readonly byval(i32) align 4 "nvvm.grid_constant" [[INPUT:%.*]]) #[[ATTR0]] {
+; OPT-NEXT:    [[INPUT2:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT]])
 ; OPT-NEXT:    [[INPUT1:%.*]] = addrspacecast ptr addrspace(101) [[INPUT2]] to ptr
+; OPT-NEXT:    [[INPUT3:%.*]] = load i32, ptr [[INPUT1]], align 4
 ; OPT-NEXT:    [[PTRVAL:%.*]] = ptrtoint ptr [[INPUT1]] to i32
 ; OPT-NEXT:    [[KEEPALIVE:%.*]] = add i32 [[INPUT3]], [[PTRVAL]]
 ; OPT-NEXT:    ret i32 [[KEEPALIVE]]
@@ -519,7 +519,7 @@ declare void @device_func(ptr byval(i32) align 4)
 
 define ptx_kernel void @test_forward_byval_arg(ptr byval(i32) align 4 "nvvm.grid_constant" %input) {
 ; OPT-LABEL: define ptx_kernel void @test_forward_byval_arg(
-; OPT-SAME: ptr byval(i32) align 4 "nvvm.grid_constant" [[INPUT:%.*]]) #[[ATTR0]] {
+; OPT-SAME: ptr readonly byval(i32) align 4 "nvvm.grid_constant" [[INPUT:%.*]]) #[[ATTR0]] {
 ; OPT-NEXT:    [[INPUT_PARAM:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT]])
 ; OPT-NEXT:    [[INPUT_PARAM_GEN:%.*]] = addrspacecast ptr addrspace(101) [[INPUT_PARAM]] to ptr
 ; OPT-NEXT:    call void @device_func(ptr byval(i32) align 4 [[INPUT_PARAM_GEN]])
diff --git a/llvm/test/CodeGen/NVPTX/lower-args.ll b/llvm/test/CodeGen/NVPTX/lower-args.ll
index b4a51035c6610..c0bc34602c29c 100644
--- a/llvm/test/CodeGen/NVPTX/lower-args.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args.ll
@@ -199,7 +199,7 @@ define ptx_kernel void @ptr_as_int(i64 noundef %i, i32 noundef %v) {
 
 define ptx_kernel void @ptr_as_int_aggr(ptr nocapture noundef readonly byval(%struct.S) align 8 %s, i32 noundef %v) {
 ; IRC-LABEL: define ptx_kernel void @ptr_as_int_aggr(
-; IRC-SAME: ptr noundef readonly byval([[STRUCT_S:%.*]]) align 8 captures(none) [[S:%.*]], i32 noundef [[V:%.*]]) {
+; IRC-SAME: ptr noundef readonly byval([[STRUCT_S:%.*]]) align 8 captures(none) "nvvm.grid_constant" [[S:%.*]], i32 noundef [[V:%.*]]) {
 ; IRC-NEXT:    [[S3:%.*]] = call align 8 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[S]])
 ; IRC-NEXT:    [[I:%.*]] = load i64, ptr addrspace(101) [[S3]], align 8
 ; IRC-NEXT:    [[P:%.*]] = inttoptr i64 [[I]] to ptr
@@ -209,7 +209,7 @@ define ptx_kernel void @ptr_as_int_aggr(ptr nocapture noundef readonly byval(%st
 ; IRC-NEXT:    ret void
 ;
 ; IRO-LABEL: define ptx_kernel void @ptr_as_int_aggr(
-; IRO-SAME: ptr noundef readonly byval([[STRUCT_S:%.*]]) align 8 captures(none) [[S:%.*]], i32 noundef [[V:%.*]]) {
+; IRO-SAME: ptr noundef readonly byval([[STRUCT_S:%.*]]) align 8 captures(none) "nvvm.grid_constant" [[S:%.*]], i32 noundef [[V:%.*]]) {
 ; IRO-NEXT:    [[S1:%.*]] = call align 8 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[S]])
 ; IRO-NEXT:    [[I:%.*]] = load i64, ptr addrspace(101) [[S1]], align 8
 ; IRO-NEXT:    [[P:%.*]] = inttoptr i64 [[I]] to ptr
diff --git a/llvm/test/CodeGen/NVPTX/lower-byval-args.ll b/llvm/test/CodeGen/NVPTX/lower-byval-args.ll
index ca2914a2e8043..31dddd4e2784e 100644
--- a/llvm/test/CodeGen/NVPTX/lower-byval-args.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-byval-args.ll
@@ -30,7 +30,7 @@ declare void @llvm.memset.p0.i64(ptr nocapture writeonly, i8, i64, i1 immarg) #2
 ; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite)
 define dso_local ptx_kernel void @read_only(ptr nocapture noundef writeonly %out, ptr nocapture noundef readonly byval(%struct.S) align 4 %s) local_unnamed_addr #0 {
 ; LOWER-ARGS-LABEL: define dso_local ptx_kernel void @read_only(
-; LOWER-ARGS-SAME: ptr noundef writeonly captures(none) [[OUT:%.*]], ptr noundef readonly byval([[STRUCT_S:%.*]]) align 4 captures(none) [[S:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+; LOWER-ARGS-SAME: ptr noundef writeonly captures(none) [[OUT:%.*]], ptr noundef readonly byval([[STRUCT_S:%.*]]) align 4 captures(none) "nvvm.grid_constant" [[S:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 ; LOWER-ARGS-NEXT:  [[ENTRY:.*:]]
 ; LOWER-ARGS-NEXT:    [[S3:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[S]])
 ; LOWER-ARGS-NEXT:    [[I:%.*]] = load i32, ptr addrspace(101) [[S3]], align 4
@@ -64,7 +64,7 @@ entry:
 ; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite)
 define dso_local ptx_kernel void @read_only_gep(ptr nocapture noundef writeonly %out, ptr nocapture noundef readonly byval(%struct.S) align 4 %s) local_unnamed_addr #0 {
 ; LOWER-ARGS-LABEL: define dso_local ptx_kernel void @read_only_gep(
-; LOWER-ARGS-SAME: ptr noundef writeonly captures(none) [[OUT:%.*]], ptr noundef readonly byval([[STRUCT_S:%.*]]) align 4 captures(none) [[S:%.*]]) local_unnamed_addr #[[ATTR0]] {
+; LOWER-ARGS-SAME: ptr noundef writeonly captures(none) [[OUT:%.*]], ptr noundef readonly byval([[STRUCT_S:%.*]]) align 4 captures(none) "nvvm.grid_constant" [[S:%.*]]) local_unnamed_addr #[[ATTR0]] {
 ; LOWER-ARGS-NEXT:  [[ENTRY:.*:]]
 ; LOWER-ARGS-NEXT:    [[S3:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[S]])
 ; LOWER-ARGS-NEXT:    [[B4:%.*]] = getelementptr inbounds i8, ptr addrspace(101) [[S3]], i64 4
@@ -125,7 +125,7 @@ entry:
 ; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite)
 define dso_local ptx_kernel void @escape_ptr(ptr nocapture noundef readnone %out, ptr noundef byval(%struct.S) align 4 %s) local_unnamed_addr #0 {
 ; COMMON-LABEL: define dso_local ptx_kernel void @escape_ptr(
-; COMMON-SAME: ptr noundef readnone captures(none) [[OUT:%.*]], ptr noundef byval([[STRUCT_S:%.*]]) align 4 [[S:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+; COMMON-SAME: ptr noundef readnone captures(none) [[OUT:%.*]], ptr noundef readonly byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[S:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 ; COMMON-NEXT:  [[ENTRY:.*:]]
 ; COMMON-NEXT:    [[S1:%.*]] = alloca [[STRUCT_S]], align 4
 ; COMMON-NEXT:    [[S2:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[S]])
@@ -164,7 +164,7 @@ entry:
 ; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite)
 define dso_local ptx_kernel void @escape_ptr_gep(ptr nocapture noundef readnone %out, ptr noundef byval(%struct.S) align 4 %s) local_unnamed_addr #0 {
 ; COMMON-LABEL: define dso_local ptx_kernel void @escape_ptr_gep(
-; COMMON-SAME: ptr noundef readnone captures(none) [[OUT:%.*]], ptr noundef byval([[STRUCT_S:%.*]]) align 4 [[S:%.*]]) local_unnamed_addr #[[ATTR0]] {
+; COMMON-SAME: ptr noundef readnone captures(none) [[OUT:%.*]], ptr noundef readonly byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[S:%.*]]) local_unnamed_addr #[[ATTR0]] {
 ; COMMON-NEXT:  [[ENTRY:.*:]]
 ; COMMON-NEXT:    [[S1:%.*]] = alloca [[STRUCT_S]], align 4
 ; COMMON-NEXT:    [[S2:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[S]])
@@ -206,7 +206,7 @@ entry:
 ; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite)
 define dso_local ptx_kernel void @escape_ptr_store(ptr nocapture noundef writeonly %out, ptr noundef byval(%struct.S) align 4 %s) local_unnamed_addr #0 {
 ; COMMON-LABEL: define dso_local ptx_kernel void @escape_ptr_store(
-; COMMON-SAME: ptr noundef writeonly captures(none) [[OUT:%.*]], ptr noundef byval([[STRUCT_S:%.*]]) align 4 [[S:%.*]]) local_unnamed_addr #[[ATTR0]] {
+; COMMON-SAME: ptr noundef writeonly captures(none) [[OUT:%.*]], ptr noundef readonly byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[S:%.*]]) local_unnamed_addr #[[ATTR0]] {
 ; COMMON-NEXT:  [[ENTRY:.*:]]
 ; COMMON-NEXT:    [[S1:%.*]] = alloca [[STRUCT_S]], align 4
 ; COMMON-NEXT:    [[S2:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[S]])
@@ -243,7 +243,7 @@ entry:
 ; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite)
 define dso_local ptx_kernel void @escape_ptr_gep_store(ptr nocapture noundef writeonly %out, ptr noundef byval(%struct.S) align 4 %s) local_unnamed_addr #0 {
 ; COMMON-LABEL: define dso_local ptx_kernel void @escape_ptr_gep_store(
-; COMMON-SAME: ptr noundef writeonly captures(none) [[OUT:%.*]], ptr noundef byval([[STRUCT_S:%.*]]) align 4 [[S:%.*]]) local_unnamed_addr #[[ATTR0]] {
+; COMMON-SAME: ptr noundef writeonly captures(none) [[OUT:%.*]], ptr noundef readonly byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[S:%.*]]) local_unnamed_addr #[[ATTR0]] {
 ; COMMON-NEXT:  [[ENTRY:.*:]]
 ; COMMON-NEXT:    [[S1:%.*]] = alloca [[STRUCT_S]], align 4
 ; COMMON-NEXT:    [[S2:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[S]])
@@ -283,7 +283,7 @@ entry:
 ; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite)
 define dso_local ptx_kernel void @escape_ptrtoint(ptr nocapture noundef writeonly %out, ptr noundef byval(%struct.S) align 4 %s) local_unnamed_addr #0 {
 ; COMMON-LABEL: define dso_local ptx_kernel void @escape_ptrtoint(
-; COMMON-SAME: ptr noundef writeonly captures(none) [[OUT:%.*]], ptr noundef byval([[STRUCT_S:%.*]]) align 4 [[S:%.*]]) local_unnamed_addr #[[ATTR0]] {
+; COMMON-SAME: ptr noundef writeonly captures(none) [[OUT:%.*]], ptr noundef readonly byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[S:%.*]]) local_unnamed_addr #[[ATTR0]] {
 ; COMMON-NEXT:  [[ENTRY:.*:]]
 ; COMMON-NEXT:    [[S1:%.*]] = alloca [[STRUCT_S]], align 4
 ; COMMON-NEXT:    [[S2:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[S]])
@@ -322,7 +322,7 @@ entry:
 ; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite)
 define dso_local ptx_kernel void @memcpy_from_param(ptr nocapture noundef writeonly %out, ptr nocapture noundef readonly byval(%struct.S) align 4 %s) local_unnamed_addr #0 {
 ; LOWER-ARGS-LABEL: define dso_local ptx_kernel void @memcpy_from_param(
-; LOWER-ARGS-SAME: ptr noundef writeonly captures(none) [[OUT:%.*]], ptr noundef readonly byval([[STRUCT_S:%.*]]) align 4 captures(none) [[S:%.*]]) local_unnamed_addr #[[ATTR0]] {
+; LOWER-ARGS-SAME: ptr noundef writeonly captures(none) [[OUT:%.*]], ptr noundef readonly byval([[STRUCT_S:%.*]]) align 4 captures(none) "nvvm.grid_constant" [[S:%.*]]) local_unnamed_addr #[[ATTR0]] {
 ; LOWER-ARGS-NEXT:  [[ENTRY:.*:]]
 ; LOWER-ARGS-NEXT:    [[S3:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[S]])
 ; LOWER-ARGS-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr [[OUT]], ptr addrspace(101) [[S3]], i64 16, i1 true)
@@ -382,9 +382,9 @@ entry:
 ; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite)
 define dso_local ptx_kernel void @memcpy_from_param_noalign (ptr nocapture noundef writeonly %out, ptr nocapture noundef readonly byval(%struct.S) %s) local_unnamed_addr #0 {
 ; LOWER-ARGS-LABEL: define dso_local ptx_kernel void @memcpy_from_param_noalign(
-; LOWER-ARGS-SAME: ptr noundef writeonly captures(none) [[OUT:%.*]], ptr noundef readonly byval([[STRUCT_S:%.*]]) align 4 captures(none) [[S:%.*]]) local_unnamed_addr #[[ATTR0]] {
+; LOWER-ARGS-SAME: ptr noundef writeonly captures(none) [[OUT:%.*]], ptr noundef readonly byval([[STRUCT_S:%.*]]) align 4 captures(none) "nvvm.grid_constant" [[S:%.*]]) local_unnamed_addr #[[ATTR0]] {
 ; LOWER-ARGS-NEXT:  [[ENTRY:.*:]]
-; LOWER-ARGS-NEXT:    [[S3:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[S]])
+; LOWER-ARGS-NEXT:    [[S3:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[S]])
 ; LOWER-ARGS-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr [[OUT]], ptr addrspace(101) [[S3]], i64 16, i1 true)
 ; LOWER-ARGS-NEXT:    ret void
 ;
@@ -442,7 +442,7 @@ entry:
 ; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite)
 define dso_local ptx_kernel void @memcpy_to_param(ptr nocapture noundef readonly %in, ptr nocapture noundef byval(%struct.S) align 4 %s) local_unnamed_addr #0 {
 ; COMMON-LABEL: define dso_local ptx_kernel void @memcpy_to_param(
-; COMMON-SAME: ptr noundef readonly captures(none) [[IN:%.*]], ptr noundef byval([[STRUCT_S:%.*]]) align 4 captures(none) [[S:%.*]]) local_unnamed_addr #[[ATTR0]] {
+; COMMON-SAME: ptr noundef readonly captures(none) [[IN:%.*]], ptr noundef readonly byval([[STRUCT_S:%.*]]) align 4 captures(none) "nvvm.grid_constant" [[S:%.*]]) local_unnamed_addr #[[ATTR0]] {
 ; COMMON-NEXT:  [[ENTRY:.*:]]
 ; COMMON-NEXT:    [[S1:%.*]] = alloca [[STRUCT_S]], align 4
 ; COMMON-NEXT:    [[S2:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[S]])
@@ -522,7 +522,7 @@ entry:
 ; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite)
 define dso_local ptx_kernel void @copy_on_store(ptr nocapture noundef readonly %in, ptr nocapture noundef byval(%struct.S) align 4 %s, i1 noundef zeroext %b) local_unnamed_addr #0 {
 ; COMMON-LABEL: define dso_local ptx_kernel void @copy_on_store(
-; COMMON-SAME: ptr noundef readonly captures(none) [[IN:%.*]], ptr noundef byval([[STRUCT_S:%.*]]) align 4 captures(none) [[S:%.*]], i1 noundef zeroext [[B:%.*]]) local_unnamed_addr #[[ATTR0]] {
+; COMMON-SAME: ptr noundef readonly captures(none) [[IN:%.*]], ptr noundef readonly byval([[STRUCT_S:%.*]]) align 4 captures(none) "nvvm.grid_constant" [[S:%.*]], i1 noundef zeroext [[B:%.*]]) local_unnamed_addr #[[ATTR0]] {
 ; COMMON-NEXT:  [[BB:.*:]]
 ; COMMON-NEXT:    [[S1:%.*]] = alloca [[STRUCT_S]], align 4
 ; COMMON-NEXT:    [[S2:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[S]])
@@ -545,10 +545,10 @@ bb:
 
 define ptx_kernel void @test_select(ptr byval(i32) align 4 %input1, ptr byval(i32) %input2, ptr %out, i1 %cond) {
 ; SM_60-LABEL: define ptx_kernel void @test_select(
-; SM_60-SAME: ptr byval(i32) align 4 [[INPUT1:%.*]], ptr byval(i32) [[INPUT2:%.*]], ptr [[OUT:%.*]], i1 [[COND:%.*]]) #[[ATTR3:[0-9]+]] {
+; SM_60-SAME: ptr readonly byval(i32) align 4 "nvvm.grid_constant" [[INPUT1:%.*]], ptr readonly byval(i32) align 4 "nvvm.grid_constant" [[INPUT2:%.*]], ptr [[OUT:%.*]], i1 [[COND:%.*]]) #[[ATTR3:[0-9]+]] {
 ; SM_60-NEXT:  [[BB:.*:]]
 ; SM_60-NEXT:    [[INPUT24:%.*]] = alloca i32, align 4
-; SM_60-NEXT:    [[INPUT25:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT2]])
+; SM_60-NEXT:    [[INPUT25:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT2]])
 ; SM_60-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 4 [[INPUT24]], ptr addrspace(101) align 4 [[INPUT25]], i64 4, i1 false)
 ; SM_60-NEXT:    [[INPUT11:%.*]] = alloca i32, align 4
 ; SM_60-NEXT:    [[INPUT12:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT1]])
@@ -559,9 +559,9 @@ define ptx_kernel void @test_select(ptr byval(i32) align 4 %input1, ptr byval(i3
 ; SM_60-NEXT:    ret void
 ;
 ; SM_70-LABEL: define ptx_kernel void @test_select(
-; SM_70-SAME: ptr byval(i32) align 4 [[INPUT1:%.*]], ptr byval(i32) [[INPUT2:%.*]], ptr [[OUT:%.*]], i1 [[COND:%.*]]) #[[ATTR3:[0-9]+]] {
+; SM_70-SAME: ptr readonly byval(i32) align 4 "nvvm.grid_constant" [[INPUT1:%.*]], ptr readonly byval(i32) align 4 "nvvm.grid_constant" [[INPUT2:%.*]], ptr [[OUT:%.*]], i1 [[COND:%.*]]) #[[ATTR3:[0-9]+]] {
 ; SM_70-NEXT:  [[BB:.*:]]
-; SM_70-NEXT:    [[TMP0:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT2]])
+; SM_70-NEXT:    [[TMP0:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT2]])
 ; SM_70-NEXT:    [[INPUT2_PARAM_GEN:%.*]] = addrspacecast ptr addrspace(101) [[TMP0]] to ptr
 ; SM_70-NEXT:    [[TMP1:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT1]])
 ; SM_70-NEXT:    [[INPUT1_PARAM_GEN:%.*]] = addrspacecast ptr addrspace(101) [[TMP1]] to ptr
@@ -571,7 +571,7 @@ define ptx_kernel void @test_select(ptr byval(i32) align 4 %input1, ptr byval(i3
 ; SM_70-NEXT:    ret void
 ;
 ; COPY-LABEL: define ptx_kernel void @test_select(
-; COPY-SAME: ptr byval(i32) align 4 [[INPUT1:%.*]], ptr byval(i32) [[INPUT2:%.*]], ptr [[OUT:%.*]], i1 [[COND:%.*]]) #[[ATTR3:[0-9]+]] {
+; COPY-SAME: ptr readonly byval(i32) align 4 "nvvm.grid_constant" [[INPUT1:%.*]], ptr readonly byval(i32) "nvvm.grid_constant" [[INPUT2:%.*]], ptr [[OUT:%.*]], i1 [[COND:%.*]]) #[[ATTR3:[0-9]+]] {
 ; COPY-NEXT:  [[BB:.*:]]
 ; COPY-NEXT:    [[INPUT23:%.*]] = alloca i32, align 4
 ; COPY-NEXT:    [[INPUT24:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT2]])
@@ -611,18 +611,31 @@ bb:
 }
 
 define ptx_kernel void @test_select_write(ptr byval(i32) align 4 %input1, ptr byval(i32) %input2, ptr %out, i1 %cond) {
-; COMMON-LABEL: define ptx_kernel void @test_select_write(
-; COMMON-SAME: ptr byval(i32) align 4 [[INPUT1:%.*]], ptr byval(i32) [[INPUT2:%.*]], ptr [[OUT:%.*]], i1 [[COND:%.*]]) #[[ATTR3:[0-9]+]] {
-; COMMON-NEXT:  [[BB:.*:]]
-; COMMON-NEXT:    [[INPUT23:%.*]] = alloca i32, align 4
-; COMMON-NEXT:    [[INPUT24:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT2]])
-; COMMON-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 4 [[INPUT23]], ptr addrspace(101) align 4 [[INPUT24]], i64 4, i1 false)
-; COMMON-NEXT:    [[INPUT11:%.*]] = alloca i32, align 4
-; COMMON-NEXT:    [[INPUT12:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT1]])
-; COMMON-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 4 [[INPUT11]], ptr addrspace(101) align 4 [[INPUT12]], i64 4, i1 false)
-; COMMON-NEXT:    [[PTRNEW:%.*]] = select i1 [[COND]], ptr [[INPUT11]], ptr [[INPUT23]]
-; COMMON-NEXT:    store i32 1, ptr [[PTRNEW]], align 4
-; COMMON-NEXT:    ret void
+; LOWER-ARGS-LABEL: define ptx_kernel void @test_select_write(
+; LOWER-ARGS-SAME: ptr readonly byval(i32) align 4 "nvvm.grid_constant" [[INPUT1:%.*]], ptr readonly byval(i32) align 4 "nvvm.grid_constant" [[INPUT2:%.*]], ptr [[OUT:%.*]], i1 [[COND:%.*]]) #[[ATTR3:[0-9]+]] {
+; LOWER-ARGS-NEXT:  [[BB:.*:]]
+; LOWER-ARGS-NEXT:    [[INPUT22:%.*]] = alloca i32, align 4
+; LOWER-ARGS-NEXT:    [[INPUT2_PARAM:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT2]])
+; LOWER-ARGS-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 4 [[INPUT22]], ptr addrspace(101) align 4 [[INPUT2_PARAM]], i64 4, i1 false)
+; LOWER-ARGS-NEXT:    [[INPUT11:%.*]] = alloca i32, align 4
+; LOWER-ARGS-NEXT:    [[INPUT1_PARAM:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT1]])
+; LOWER-ARGS-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 4 [[INPUT11]], ptr addrspace(101) align 4 [[INPUT1_PARAM]], i64 4, i1 false)
+; LOWER-ARGS-NEXT:    [[PTRNEW:%.*]] = select i1 [[COND]], ptr [[INPUT11]], ptr [[INPUT22]]
+; LOWER-ARGS-NEXT:    store i32 1, ptr [[PTRNEW]], align 4
+; LOWER-ARGS-NEXT:    ret void
+;
+; COPY-LABEL: define ptx_kernel void @test_select_write(
+; COPY-SAME: ptr readonly byval(i32) align 4 "nvvm.grid_constant" [[INPUT1:%.*]], ptr readonly byval(i32) "nvvm.grid_constant" [[INPUT2:%.*]], ptr [[OUT:%.*]], i1 [[COND:%.*]]) #[[ATTR3]] {
+; COPY-NEXT:  [[BB:.*:]]
+; COPY-NEXT:    [[INPUT22:%.*]] = alloca i32, align 4
+; COPY-NEXT:    [[INPUT2_PARAM:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT2]])
+; COPY-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 4 [[INPUT22]], ptr addrspace(101) align 4 [[INPUT2_PARAM]], i64 4, i1 false)
+; COPY-NEXT:    [[INPUT11:%.*]] = alloca i32, align 4
+; COPY-NEXT:    [[INPUT1_PARAM:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT1]])
+; COPY-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 4 [[INPUT11]], ptr addrspace(101) align 4 [[INPUT1_PARAM]], i64 4, i1 false)
+; COPY-NEXT:    [[PTRNEW:%.*]] = select i1 [[COND]], ptr [[INPUT11]], ptr [[INPUT22]]
+; COPY-NEXT:    store i32 1, ptr [[PTRNEW]], align 4
+; COPY-NEXT:    ret void
 ;
 ; PTX-LABEL: test_select_write(
 ; PTX:       {
@@ -657,20 +670,20 @@ bb:
 
 define ptx_kernel void @test_phi(ptr byval(%struct.S) align 4 %input1, ptr byval(%struct.S) %input2, ptr %inout, i1 %cond) {
 ; SM_60-LABEL: define ptx_kernel void @test_phi(
-; SM_60-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT1:%.*]], ptr byval([[STRUCT_S]]) [[INPUT2:%.*]], ptr [[INOUT:%.*]], i1 [[COND:%.*]]) #[[ATTR3]] {
+; SM_60-SAME: ptr readonly byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT1:%.*]], ptr readonly byval([[STRUCT_S]]) align 4 "nvvm.grid_constant" [[INPUT2:%.*]], ptr [[INOUT:%.*]], i1 [[COND:%.*]]) #[[ATTR3]] {
 ; SM_60-NEXT:  [[BB:.*:]]
-; SM_60-NEXT:    [[INPUT24:%.*]] = alloca [[STRUCT_S]], align 8
-; SM_60-NEXT:    [[INPUT25:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT2]])
-; SM_60-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 8 [[INPUT24]], ptr addrspace(101) align 8 [[INPUT25]], i64 8, i1 false)
 ; SM_60-NEXT:    [[INPUT11:%.*]] = alloca [[STRUCT_S]], align 4
+; SM_60-NEXT:    [[INPUT2_PARAM:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT2]])
+; SM_60-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 4 [[INPUT11]], ptr addrspace(101) align 4 [[INPUT2_PARAM]], i64 8, i1 false)
+; SM_60-NEXT:    [[INPUT13:%.*]] = alloca [[STRUCT_S]], align 4
 ; SM_60-NEXT:    [[INPUT12:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT1]])
-; SM_60-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 4 [[INPUT11]], ptr addrspace(101) align 4 [[INPUT12]], i64 8, i1 false)
+; SM_60-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 4 [[INPUT13]], ptr addrspace(101) align 4 [[INPUT12]], i64 8, i1 false)
 ; SM_60-NEXT:    br i1 [[COND]], label %[[FIRST:.*]], label %[[SECOND:.*]]
 ; SM_60:       [[FIRST]]:
-; SM_60-NEXT:    [[PTR1:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[INPUT11]], i32 0, i32 0
+; SM_60-NEXT:    [[PTR1:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[INPUT13]], i32 0, i32 0
 ; SM_60-NEXT:    br label %[[MERGE:.*]]
 ; SM_60:       [[SECOND]]:
-; SM_60-NEXT:    [[PTR2:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[INPUT24]], i32 0, i32 1
+; SM_60-NEXT:    [[PTR2:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[INPUT11]], i32 0, i32 1
 ; SM_60-NEXT:    br label %[[MERGE]]
 ; SM_60:       [[MERGE]]:
 ; SM_60-NEXT:    [[PTRNEW:%.*]] = phi ptr [ [[PTR1]], %[[FIRST]] ], [ [[PTR2]], %[[SECOND]] ]
@@ -679,9 +692,9 @@ define ptx_kernel void @test_phi(ptr byval(%struct.S) align 4 %input1, ptr byval
 ; SM_60-NEXT:    ret void
 ;
 ; SM_70-LABEL: define ptx_kernel void @test_phi(
-; SM_70-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT1:%.*]], ptr byval([[STRUCT_S]]) [[INPUT2:%.*]], ptr [[INOUT:%.*]], i1 [[COND:%.*]]) #[[ATTR3]] {
+; SM_70-SAME: ptr readonly byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT1:%.*]], ptr readonly byval([[STRUCT_S]]) align 4 "nvvm.grid_constant" [[INPUT2:%.*]], ptr [[INOUT:%.*]], i1 [[COND:%.*]]) #[[ATTR3]] {
 ; SM_70-NEXT:  [[BB:.*:]]
-; SM_70-NEXT:    [[TMP0:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT2]])
+; SM_70-NEXT:    [[TMP0:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT2]])
 ; SM_70-NEXT:    [[INPUT2_PARAM_GEN:%.*]] = addrspacecast ptr addrspace(101) [[TMP0]] to ptr
 ; SM_70-NEXT:    [[TMP1:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT1]])
 ; SM_70-NEXT:    [[INPUT1_PARAM_GEN:%.*]] = addrspacecast ptr addrspace(101) [[TMP1]] to ptr
@@ -699,7 +712,7 @@ define ptx_kernel void @test_phi(ptr byval(%struct.S) align 4 %input1, ptr byval
 ; SM_70-NEXT:    ret void
 ;
 ; COPY-LABEL: define ptx_kernel void @test_phi(
-; COPY-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT1:%.*]], ptr byval([[STRUCT_S]]) [[INPUT2:%.*]], ptr [[INOUT:%.*]], i1 [[COND:%.*]]) #[[ATTR3]] {
+; COPY-SAME: ptr readonly byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT1:%.*]], ptr readonly byval([[STRUCT_S]]) "nvvm.grid_constant" [[INPUT2:%.*]], ptr [[INOUT:%.*]], i1 [[COND:%.*]]) #[[ATTR3]] {
 ; COPY-NEXT:  [[BB:.*:]]
 ; COPY-NEXT:    [[INPUT23:%.*]] = alloca [[STRUCT_S]], align 8
 ; COPY-NEXT:    [[INPUT24:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT2]])
@@ -782,26 +795,47 @@ merge:                                            ; preds = %second, %first
 }
 
 define ptx_kernel void @test_phi_write(ptr byval(%struct.S) align 4 %input1, ptr byval(%struct.S) %input2, i1 %cond) {
-; COMMON-LABEL: define ptx_kernel void @test_phi_write(
-; COMMON-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT1:%.*]], ptr byval([[STRUCT_S]]) [[INPUT2:%.*]], i1 [[COND:%.*]]) #[[ATTR3]] {
-; COMMON-NEXT:  [[BB:.*:]]
-; COMMON-NEXT:    [[INPUT24:%.*]] = alloca [[STRUCT_S]], align 8
-; COMMON-NEXT:    [[INPUT25:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT2]])
-; COMMON-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 8 [[INPUT24]], ptr addrspace(101) align 8 [[INPUT25]], i64 8, i1 false)
-; COMMON-NEXT:    [[INPUT11:%.*]] = alloca [[STRUCT_S]], align 4
-; COMMON-NEXT:    [[INPUT12:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT1]])
-; COMMON-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 4 [[INPUT11]], ptr addrspace(101) align 4 [[INPUT12]], i64 8, i1 false)
-; COMMON-NEXT:    br i1 [[COND]], label %[[FIRST:.*]], label %[[SECOND:.*]]
-; COMMON:       [[FIRST]]:
-; COMMON-NEXT:    [[PTR1:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[INPUT11]], i32 0, i32 0
-; COMMON-NEXT:    br label %[[MERGE:.*]]
-; COMMON:       [[SECOND]]:
-; COMMON-NEXT:    [[PTR2:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[INPUT24]], i32 0, i32 1
-; COMMON-NEXT:    br label %[[MERGE]]
-; COMMON:       [[MERGE]]:
-; COMMON-NEXT:    [[PTRNEW:%.*]] = phi ptr [ [[PTR1]], %[[FIRST]] ], [ [[PTR2]], %[[SECOND]] ]
-; COMMON-NEXT:    store i32 1, ptr [[PTRNEW]], align 4
-; COMMON-NEXT:    ret void
+; LOWER-ARGS-LABEL: define ptx_kernel void @test_phi_write(
+; LOWER-ARGS-SAME: ptr readonly byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT1:%.*]], ptr readonly byval([[STRUCT_S]]) align 4 "nvvm.grid_constant" [[INPUT2:%.*]], i1 [[COND:%.*]]) #[[ATTR3]] {
+; LOWER-ARGS-NEXT:  [[BB:.*:]]
+; LOWER-ARGS-NEXT:    [[INPUT22:%.*]] = alloca [[STRUCT_S]], align 4
+; LOWER-ARGS-NEXT:    [[INPUT2_PARAM:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT2]])
+; LOWER-ARGS-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 4 [[INPUT22]], ptr addrspace(101) align 4 [[INPUT2_PARAM]], i64 8, i1 false)
+; LOWER-ARGS-NEXT:    [[INPUT11:%.*]] = alloca [[STRUCT_S]], align 4
+; LOWER-ARGS-NEXT:    [[INPUT1_PARAM:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT1]])
+; LOWER-ARGS-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 4 [[INPUT11]], ptr addrspace(101) align 4 [[INPUT1_PARAM]], i64 8, i1 false)
+; LOWER-ARGS-NEXT:    br i1 [[COND]], label %[[FIRST:.*]], label %[[SECOND:.*]]
+; LOWER-ARGS:       [[FIRST]]:
+; LOWER-ARGS-NEXT:    [[PTR1:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[INPUT11]], i32 0, i32 0
+; LOWER-ARGS-NEXT:    br label %[[MERGE:.*]]
+; LOWER-ARGS:       [[SECOND]]:
+; LOWER-ARGS-NEXT:    [[PTR2:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[INPUT22]], i32 0, i32 1
+; LOWER-ARGS-NEXT:    br label %[[MERGE]]
+; LOWER-ARGS:       [[MERGE]]:
+; LOWER-ARGS-NEXT:    [[PTRNEW:%.*]] = phi ptr [ [[PTR1]], %[[FIRST]] ], [ [[PTR2]], %[[SECOND]] ]
+; LOWER-ARGS-NEXT:    store i32 1, ptr [[PTRNEW]], align 4
+; LOWER-ARGS-NEXT:    ret void
+;
+; COPY-LABEL: define ptx_kernel void @test_phi_write(
+; COPY-SAME: ptr readonly byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT1:%.*]], ptr readonly byval([[STRUCT_S]]) "nvvm.grid_constant" [[INPUT2:%.*]], i1 [[COND:%.*]]) #[[ATTR3]] {
+; COPY-NEXT:  [[BB:.*:]]
+; COPY-NEXT:    [[INPUT22:%.*]] = alloca [[STRUCT_S]], align 8
+; COPY-NEXT:    [[INPUT2_PARAM:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT2]])
+; COPY-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 8 [[INPUT22]], ptr addrspace(101) align 8 [[INPUT2_PARAM]], i64 8, i1 false)
+; COPY-NEXT:    [[INPUT11:%.*]] = alloca [[STRUCT_S]], align 4
+; COPY-NEXT:    [[INPUT1_PARAM:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT1]])
+; COPY-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 4 [[INPUT11]], ptr addrspace(101) align 4 [[INPUT1_PARAM]], i64 8, i1 false)
+; COPY-NEXT:    br i1 [[COND]], label %[[FIRST:.*]], label %[[SECOND:.*]]
+; COPY:       [[FIRST]]:
+; COPY-NEXT:    [[PTR1:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[INPUT11]], i32 0, i32 0
+; COPY-NEXT:    br label %[[MERGE:.*]]
+; COPY:       [[SECOND]]:
+; COPY-NEXT:    [[PTR2:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[INPUT22]], i32 0, i32 1
+; COPY-NEXT:    br label %[[MERGE]]
+; COPY:       [[MERGE]]:
+; COPY-NEXT:    [[PTRNEW:%.*]] = phi ptr [ [[PTR1]], %[[FIRST]] ], [ [[PTR2]], %[[SECOND]] ]
+; COPY-NEXT:    store i32 1, ptr [[PTRNEW]], align 4
+; COPY-NEXT:    ret void
 ;
 ; PTX-LABEL: test_phi_write(
 ; PTX:       {
@@ -850,7 +884,7 @@ merge:                                            ; preds = %second, %first
 
 define ptx_kernel void @test_forward_byval_arg(ptr byval(i32) align 4 %input) {
 ; COMMON-LABEL: define ptx_kernel void @test_forward_byval_arg(
-; COMMON-SAME: ptr byval(i32) align 4 [[INPUT:%.*]]) #[[ATTR3]] {
+; COMMON-SAME: ptr readonly byval(i32) align 4 "nvvm.grid_constant" [[INPUT:%.*]]) #[[ATTR3:[0-9]+]] {
 ; COMMON-NEXT:    [[INPUT1:%.*]] = alloca i32, align 4
 ; COMMON-NEXT:    [[INPUT2:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT]])
 ; COMMON-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 4 [[INPUT1]], ptr addrspace(101) align 4 [[INPUT2]], i64 4, i1 false)

>From b27973b921a611074264e91a61bf26d78ee18ca0 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Fri, 13 Feb 2026 04:53:57 +0000
Subject: [PATCH 2/3] address comments

---
 llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp    | 86 +++++++--------------
 llvm/test/CodeGen/NVPTX/lower-byval-args.ll |  2 +-
 2 files changed, 30 insertions(+), 58 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index c9d761345925d..4773e293c1f39 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -140,6 +140,7 @@
 #include "NVPTXTargetMachine.h"
 #include "NVPTXUtilities.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Analysis/PtrUseVisitor.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/TargetPassConfig.h"
@@ -188,14 +189,14 @@ INITIALIZE_PASS_END(NVPTXLowerArgsLegacyPass, "nvptx-lower-args",
                     "Lower arguments (NVPTX)", false, false)
 
 // =============================================================================
-// If the function had a byval struct ptr arg, say foo(%struct.x* byval %d),
+// If the function had a byval struct ptr arg, say foo(ptr byval(%struct.x) %d),
 // and we can't guarantee that the only accesses are loads,
 // then add the following instructions to the first basic block:
 //
 // %temp = alloca %struct.x, align 8
-// %tempd = addrspacecast %struct.x* %d to %struct.x addrspace(101)*
-// %tv = load %struct.x addrspace(101)* %tempd
-// store %struct.x %tv, %struct.x* %temp, align 8
+// %tempd = addrspacecast ptr %d to ptr addrspace(101)
+// %tv = load %struct.x, ptr addrspace(101) %tempd
+// store %struct.x %tv, ptr %temp, align 8
 //
 // The above code allocates some space in the stack and copies the incoming
 // struct from param space to local space.
@@ -205,49 +206,42 @@ INITIALIZE_PASS_END(NVPTXLowerArgsLegacyPass, "nvptx-lower-args",
 // ones in parameter AS, so we can access them using ld.param.
 // =============================================================================
 
-// For Loads, replaces the \p OldUse of the pointer with a Use of the same
-// pointer in parameter AS.
-// For "escapes" (to memory, a function call, or a ptrtoint), cast the OldUse to
-// generic using cvta.param.
-static void convertToParamAS(Use *OldUse, Value *Param, bool HasCvtaParam) {
-  Instruction *I = dyn_cast<Instruction>(OldUse->getUser());
-  assert(I && "OldUse must be in an instruction");
+/// Recursively convert the users of a param to the param address space.
+static void convertToParamAS(ArrayRef<Use *> OldUses, Value *Param) {
   struct IP {
     Use *OldUse;
-    Instruction *OldInstruction;
     Value *NewParam;
   };
-  SmallVector<IP> ItemsToConvert = {{OldUse, I, Param}};
-  SmallVector<Instruction *> InstructionsToDelete;
 
-  auto CloneInstInParamAS = [HasCvtaParam](const IP &I) -> Value * {
-    if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
+  const auto CloneInstInParamAS = [](const IP &I) -> Value * {
+    auto *OldInst = cast<Instruction>(I.OldUse->getUser());
+    if (auto *LI = dyn_cast<LoadInst>(OldInst)) {
       LI->setOperand(0, I.NewParam);
       return LI;
     }
-    if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction)) {
+    if (auto *GEP = dyn_cast<GetElementPtrInst>(OldInst)) {
       SmallVector<Value *, 4> Indices(GEP->indices());
       auto *NewGEP = GetElementPtrInst::Create(
           GEP->getSourceElementType(), I.NewParam, Indices, GEP->getName(),
           GEP->getIterator());
-      NewGEP->setIsInBounds(GEP->isInBounds());
+      NewGEP->setNoWrapFlags(GEP->getNoWrapFlags());
       return NewGEP;
     }
-    if (auto *BC = dyn_cast<BitCastInst>(I.OldInstruction)) {
+    if (auto *BC = dyn_cast<BitCastInst>(OldInst)) {
       auto *NewBCType = PointerType::get(BC->getContext(), ADDRESS_SPACE_PARAM);
       return BitCastInst::Create(BC->getOpcode(), I.NewParam, NewBCType,
                                  BC->getName(), BC->getIterator());
     }
-    if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I.OldInstruction)) {
+    if (auto *ASC = dyn_cast<AddrSpaceCastInst>(OldInst)) {
       assert(ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM);
       (void)ASC;
       // Just pass through the argument, the old ASC is no longer needed.
       return I.NewParam;
     }
-    if (auto *MI = dyn_cast<MemTransferInst>(I.OldInstruction)) {
+    if (auto *MI = dyn_cast<MemTransferInst>(OldInst)) {
       if (MI->getRawSource() == I.OldUse->get()) {
         // convert to memcpy/memmove from param space.
-        IRBuilder<> Builder(I.OldInstruction);
+        IRBuilder<> Builder(OldInst);
         Intrinsic::ID ID = MI->getIntrinsicID();
 
         CallInst *B = Builder.CreateMemTransferInst(
@@ -258,55 +252,34 @@ static void convertToParamAS(Use *OldUse, Value *Param, bool HasCvtaParam) {
             B->addDereferenceableParamAttr(I, Bytes);
         return B;
       }
-      // We may be able to handle other cases if the argument is
-      // __grid_constant__
-    }
-
-    if (HasCvtaParam) {
-      auto GetParamAddrCastToGeneric =
-          [](Value *Addr, Instruction *OriginalUser) -> Value * {
-        IRBuilder<> IRB(OriginalUser);
-        Type *GenTy = IRB.getPtrTy(ADDRESS_SPACE_GENERIC);
-        return IRB.CreateAddrSpaceCast(Addr, GenTy, Addr->getName() + ".gen");
-      };
-      auto *ParamInGenericAS =
-          GetParamAddrCastToGeneric(I.NewParam, I.OldInstruction);
-
-      // phi/select could use generic arg pointers w/o __grid_constant__
-      if (auto *PHI = dyn_cast<PHINode>(I.OldInstruction)) {
-        for (auto [Idx, V] : enumerate(PHI->incoming_values())) {
-          if (V.get() == I.OldUse->get())
-            PHI->setIncomingValue(Idx, ParamInGenericAS);
-        }
-      }
-      if (auto *SI = dyn_cast<SelectInst>(I.OldInstruction)) {
-        if (SI->getTrueValue() == I.OldUse->get())
-          SI->setTrueValue(ParamInGenericAS);
-        if (SI->getFalseValue() == I.OldUse->get())
-          SI->setFalseValue(ParamInGenericAS);
-      }
     }
 
     llvm_unreachable("Unsupported instruction");
   };
 
+  auto ItemsToConvert = map_to_vector(OldUses, [=](Use *U) -> IP {
+    return {U, Param};
+  });
+  SmallVector<Instruction *> InstructionsToDelete;
+
   while (!ItemsToConvert.empty()) {
     IP I = ItemsToConvert.pop_back_val();
     Value *NewInst = CloneInstInParamAS(I);
+    Instruction *OldInst = cast<Instruction>(I.OldUse->getUser());
 
-    if (NewInst && NewInst != I.OldInstruction) {
+    if (NewInst && NewInst != OldInst) {
       // We've created a new instruction. Queue users of the old instruction to
       // be converted and the instruction itself to be deleted. We can't delete
       // the old instruction yet, because it's still in use by a load somewhere.
-      for (Use &U : I.OldInstruction->uses())
-        ItemsToConvert.push_back({&U, cast<Instruction>(U.getUser()), NewInst});
+      for (Use &U : OldInst->uses())
+        ItemsToConvert.push_back({&U, NewInst});
 
-      InstructionsToDelete.push_back(I.OldInstruction);
+      InstructionsToDelete.push_back(OldInst);
     }
   }
 
   // Now we know that all argument loads are using addresses in parameter space
-  // and we can finally remove the old instructions in generic AS.  Instructions
+  // and we can finally remove the old instructions in generic AS. Instructions
   // scheduled for removal should be processed in reverse order so the ones
   // closest to the load are deleted first. Otherwise they may still be in use.
   // E.g if we have Value = Load(BitCast(GEP(arg))), InstructionsToDelete will
@@ -528,7 +501,6 @@ static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
   Function *F = Arg->getParent();
   assert(isKernelFunction(*F));
   const NVPTXSubtarget *ST = TM.getSubtargetImpl(*F);
-  const bool HasCvtaParam = ST->hasCvtaParam();
 
   const DataLayout &DL = F->getDataLayout();
   IRBuilder<> IRB(&F->getEntryBlock().front());
@@ -549,14 +521,14 @@ static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
     SmallVector<Use *, 16> UsesToUpdate(llvm::make_pointer_range(Arg->uses()));
     Value *ArgInParamAS = createNVVMInternalAddrspaceWrap(IRB, *Arg);
     for (Use *U : UsesToUpdate)
-      convertToParamAS(U, ArgInParamAS, HasCvtaParam);
+      convertToParamAS(U, ArgInParamAS);
 
     propagateAlignmentToLoads(ArgInParamAS, NewArgAlign, DL);
     return;
   }
 
   // (2) If the argument is grid constant, we get to use the pointer directly.
-  if (HasCvtaParam && (ArgUseIsReadOnly || isParamGridConstant(*Arg))) {
+  if (ST->hasCvtaParam() && (ArgUseIsReadOnly || isParamGridConstant(*Arg))) {
     LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << *Arg << "\n");
 
     // Cast argument to param address space. Because the backend will emit the
diff --git a/llvm/test/CodeGen/NVPTX/lower-byval-args.ll b/llvm/test/CodeGen/NVPTX/lower-byval-args.ll
index 31dddd4e2784e..827097e90e7d3 100644
--- a/llvm/test/CodeGen/NVPTX/lower-byval-args.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-byval-args.ll
@@ -67,7 +67,7 @@ define dso_local ptx_kernel void @read_only_gep(ptr nocapture noundef writeonly
 ; LOWER-ARGS-SAME: ptr noundef writeonly captures(none) [[OUT:%.*]], ptr noundef readonly byval([[STRUCT_S:%.*]]) align 4 captures(none) "nvvm.grid_constant" [[S:%.*]]) local_unnamed_addr #[[ATTR0]] {
 ; LOWER-ARGS-NEXT:  [[ENTRY:.*:]]
 ; LOWER-ARGS-NEXT:    [[S3:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[S]])
-; LOWER-ARGS-NEXT:    [[B4:%.*]] = getelementptr inbounds i8, ptr addrspace(101) [[S3]], i64 4
+; LOWER-ARGS-NEXT:    [[B4:%.*]] = getelementptr inbounds nuw i8, ptr addrspace(101) [[S3]], i64 4
 ; LOWER-ARGS-NEXT:    [[I:%.*]] = load i32, ptr addrspace(101) [[B4]], align 4
 ; LOWER-ARGS-NEXT:    store i32 [[I]], ptr [[OUT]], align 4
 ; LOWER-ARGS-NEXT:    ret void

>From 3fa04eb3f5ca23d2a7cbf679f04a1f031982ad2d Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Fri, 13 Feb 2026 05:08:19 +0000
Subject: [PATCH 3/3] clang-format

---
 llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index 4773e293c1f39..aa7f60a84f3b1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -257,9 +257,8 @@ static void convertToParamAS(ArrayRef<Use *> OldUses, Value *Param) {
     llvm_unreachable("Unsupported instruction");
   };
 
-  auto ItemsToConvert = map_to_vector(OldUses, [=](Use *U) -> IP {
-    return {U, Param};
-  });
+  auto ItemsToConvert =
+      map_to_vector(OldUses, [=](Use *U) -> IP { return {U, Param}; });
   SmallVector<Instruction *> InstructionsToDelete;
 
   while (!ItemsToConvert.empty()) {



More information about the llvm-commits mailing list