[llvm] [NVPTX] Cleanup NVPTXLowerArgs, simplifying logic and improving alignment propagation (PR #180286)
Alex MacLean via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 12 20:18:49 PST 2026
================
@@ -491,120 +466,119 @@ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
Base::visitAddrSpaceCastInst(ASC);
}
- void visitPtrToIntInst(PtrToIntInst &I) {
- if (IsGridConstant)
- return;
- Base::visitPtrToIntInst(I);
- }
+ void visitPtrToIntInst(PtrToIntInst &I) { Base::visitPtrToIntInst(I); }
+
void visitPHINodeOrSelectInst(Instruction &I) {
assert(isa<PHINode>(I) || isa<SelectInst>(I));
+ enqueueUsers(I);
+ Conditionals.insert(&I);
}
// PHI and select just pass through the pointers.
- void visitPHINode(PHINode &PN) { enqueueUsers(PN); }
- void visitSelectInst(SelectInst &SI) { enqueueUsers(SI); }
+ void visitPHINode(PHINode &PN) { visitPHINodeOrSelectInst(PN); }
+ void visitSelectInst(SelectInst &SI) { visitPHINodeOrSelectInst(SI); }
+ // memcpy/memmove are OK when the pointer is source. We can convert them to
+ // AS-specific memcpy.
void visitMemTransferInst(MemTransferInst &II) {
- if (*U == II.getRawDest() && !IsGridConstant)
+ if (*U == II.getRawDest())
PI.setAborted(&II);
- // memcpy/memmove are OK when the pointer is source. We can convert them to
- // AS-specific memcpy.
}
- void visitMemSetInst(MemSetInst &II) {
- if (!IsGridConstant)
- PI.setAborted(&II);
- }
+ void visitMemSetInst(MemSetInst &II) { PI.setAborted(&II); }
}; // struct ArgUseChecker
void copyByValParam(Function &F, Argument &Arg) {
LLVM_DEBUG(dbgs() << "Creating a local copy of " << Arg << "\n");
- // Otherwise we have to create a temporary copy.
- BasicBlock::iterator FirstInst = F.getEntryBlock().begin();
- Type *StructType = Arg.getParamByValType();
+ Type *ByValType = Arg.getParamByValType();
const DataLayout &DL = F.getDataLayout();
- IRBuilder<> IRB(&*FirstInst);
- AllocaInst *AllocA = IRB.CreateAlloca(StructType, nullptr, Arg.getName());
+ IRBuilder<> IRB(&F.getEntryBlock().front());
+ AllocaInst *AllocA = IRB.CreateAlloca(ByValType, nullptr, Arg.getName());
// Set the alignment to alignment of the byval parameter. This is because,
// later load/stores assume that alignment, and we are going to replace
// the use of the byval parameter with this alloca instruction.
AllocA->setAlignment(
- Arg.getParamAlign().value_or(DL.getPrefTypeAlign(StructType)));
+ Arg.getParamAlign().value_or(DL.getPrefTypeAlign(ByValType)));
Arg.replaceAllUsesWith(AllocA);
- CallInst *ArgInParam = createNVVMInternalAddrspaceWrap(IRB, Arg);
+ Value *ArgInParamAS = createNVVMInternalAddrspaceWrap(IRB, Arg);
// Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
// addrspacecast preserves alignment. Since params are constant, this load
// is definitely not volatile.
const auto ArgSize = *AllocA->getAllocationSize(DL);
- IRB.CreateMemCpy(AllocA, AllocA->getAlign(), ArgInParam, AllocA->getAlign(),
+ IRB.CreateMemCpy(AllocA, AllocA->getAlign(), ArgInParamAS, AllocA->getAlign(),
ArgSize);
}
} // namespace
+static bool argIsProcessed(Argument *Arg) {
+ if (Arg->use_empty())
+ return true;
+
+ // If the argument is already wrapped, it was processed by this pass before.
+ if (Arg->hasOneUse())
+ if (const auto *II = dyn_cast<IntrinsicInst>(*Arg->user_begin()))
+ if (II->getIntrinsicID() == Intrinsic::nvvm_internal_addrspace_wrap)
+ return true;
+
+ return false;
+}
+
static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
- Function *Func = Arg->getParent();
- assert(isKernelFunction(*Func));
- const bool HasCvtaParam = TM.getSubtargetImpl(*Func)->hasCvtaParam();
- const bool IsGridConstant = HasCvtaParam && isParamGridConstant(*Arg);
- const DataLayout &DL = Func->getDataLayout();
- BasicBlock::iterator FirstInst = Func->getEntryBlock().begin();
- [[maybe_unused]] Type *StructType = Arg->getParamByValType();
- assert(StructType && "Missing byval type");
-
- ArgUseChecker AUC(DL, IsGridConstant);
+ Function *F = Arg->getParent();
+ assert(isKernelFunction(*F));
+ const NVPTXSubtarget *ST = TM.getSubtargetImpl(*F);
+ const bool HasCvtaParam = ST->hasCvtaParam();
+
+ const DataLayout &DL = F->getDataLayout();
+ IRBuilder<> IRB(&F->getEntryBlock().front());
+
+ if (argIsProcessed(Arg))
+ return;
+
+ const Align NewArgAlign = setByValParamAlign(Arg, ST->getTargetLowering());
+
+ // (1) First check the easy case, if were able to trace through all the uses
+ // and we can convert them all to param AS, then we'll do this.
+ ArgUseChecker AUC(DL);
ArgUseChecker::PtrInfo PI = AUC.visitArgPtr(*Arg);
- bool ArgUseIsReadOnly = !(PI.isEscaped() || PI.isAborted());
- // Easy case, accessing parameter directly is fine.
+ const bool ArgUseIsReadOnly = !(PI.isEscaped() || PI.isAborted());
if (ArgUseIsReadOnly && AUC.Conditionals.empty()) {
// Convert all loads and intermediate operations to use parameter AS and
// skip creation of a local copy of the argument.
SmallVector<Use *, 16> UsesToUpdate(llvm::make_pointer_range(Arg->uses()));
-
- IRBuilder<> IRB(&*FirstInst);
- CallInst *ArgInParamAS = createNVVMInternalAddrspaceWrap(IRB, *Arg);
-
+ Value *ArgInParamAS = createNVVMInternalAddrspaceWrap(IRB, *Arg);
for (Use *U : UsesToUpdate)
- convertToParamAS(U, ArgInParamAS, HasCvtaParam, IsGridConstant);
- LLVM_DEBUG(dbgs() << "No need to copy or cast " << *Arg << "\n");
-
- const auto *TLI =
- cast<NVPTXTargetLowering>(TM.getSubtargetImpl()->getTargetLowering());
-
- adjustByValArgAlignment(Arg, ArgInParamAS, TLI);
+ convertToParamAS(U, ArgInParamAS, HasCvtaParam);
+ propagateAlignmentToLoads(ArgInParamAS, NewArgAlign, DL);
return;
}
- // We can't access byval arg directly and need a pointer. on sm_70+ we have
- // ability to take a pointer to the argument without making a local copy.
- // However, we're still not allowed to write to it. If the user specified
- // `__grid_constant__` for the argument, we'll consider escaped pointer as
- // read-only.
- if (IsGridConstant || (HasCvtaParam && ArgUseIsReadOnly)) {
----------------
AlexMaclean wrote:
Yea, we already have several tests for this, such as `@grid_const_escape`. The final codegen is unchanged.
https://github.com/llvm/llvm-project/pull/180286
More information about the llvm-commits
mailing list