[llvm] [NVPTX] Improve device function byval parameter lowering (PR #129188)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 27 21:28:23 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-backend-nvptx
Author: Alex MacLean (AlexMaclean)
<details>
<summary>Changes</summary>
PTX supports two methods of accessing device function parameters:
- "simple" case: If a parameter is only loaded, and all loads can address the parameter via a constant offset, then the parameter may be loaded via the ".param" address space. This case is not possible if the parameter is stored to or has its address taken. This method is preferable when possible.
- "move param" case: For more complex cases, the address of the param may be placed in a register via a "mov" instruction. This "mov" also implicitly moves the param to the ".local" address space and allows it to be written to. This essentially defers the responsibility of the byval copy to the PTX calling convention. (Both forms are sketched below.)
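
For concreteness, the two forms look roughly like this at the PTX level. These snippets mirror the examples in the new pass's header comment; the `foo_param_*` symbols and register numbers are placeholders:

```ptx
// "simple" case: the parameter is read directly from the .param space.
ld.param.u32  %r1, [foo_param_1];
ld.param.u32  %r2, [foo_param_1+4];

// "move param" case: the parameter's address is materialized in a register,
// implicitly placing the parameter in the .local space, so it can be stored
// to or indexed dynamically.
mov.b64       %rd1, foo_param_0;
st.local.u32  [%rd1], 42;
add.u64       %rd3, %rd1, %rd2;
ld.local.u32  %r2, [%rd3];
```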
The handling of these cases in the NVPTX backend for byval pointers has some major issues. We currently attempt to determine in NVPTXLowerArgs whether a copy is necessary, and either explicitly make an additional copy in the IR or insert an "addrspacecast" to move the param to the param address space. Unfortunately, the criteria for determining which case is possible are not correct, leading to miscompilations (https://godbolt.org/z/Gq1fP7a3G). Further, the criteria for the "simple" case aren't enforceable in LLVM IR across other transformations and instruction selection, making the choice between the two cases in NVPTXLowerArgs brittle and buggy.
This patch aims to fix these issues and improve address-space-related optimization. In NVPTXLowerArgs, we conservatively assume that all parameters will use the "move param" case and the local address space. Responsibility for switching to the "simple" case is given to a new MachineIR pass, NVPTXForwardParams, which runs once it has become clear whether or not this is possible. This ensures that the correct address space is known for the "move param" case, allowing for optimization, while still using the "simple" case wherever possible.
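
To make the division of labor concrete, here is a minimal LLVM IR sketch (function and value names are hypothetical) of roughly what NVPTXLowerArgs now produces for a device-function byval argument: the pointer is conservatively cast into the local address space (addrspace(5) on NVPTX) and back to generic, and NVPTXForwardParams later rewrites the resulting loads to use ".param" directly when every use turns out to be a simple load:

```llvm
; Sketch only: after NVPTXLowerArgs, the byval pointer is routed through
; .local (addrspace(5)) and all original uses go through the generic copy.
define i32 @callee(ptr byval(i32) align 4 %a) {
  %a.local = addrspacecast ptr %a to ptr addrspace(5)
  %a.gen = addrspacecast ptr addrspace(5) %a.local to ptr
  %v = load i32, ptr %a.gen
  ret i32 %v
}
```

If the argument's address escapes or is stored to, the casts remain and the accesses stay in ".local", matching the "move param" convention.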
---
Patch is 35.87 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129188.diff
13 Files Affected:
- (modified) llvm/lib/Target/NVPTX/CMakeLists.txt (+1)
- (modified) llvm/lib/Target/NVPTX/NVPTX.h (+1)
- (added) llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp (+169)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+2-2)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+12-4)
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+3-18)
- (modified) llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp (+27-36)
- (modified) llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp (+3)
- (added) llvm/test/CodeGen/NVPTX/forward-ld-param.ll (+142)
- (modified) llvm/test/CodeGen/NVPTX/i128-array.ll (+7-8)
- (modified) llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll (+26-24)
- (modified) llvm/test/CodeGen/NVPTX/lower-args.ll (+8-15)
- (modified) llvm/test/CodeGen/NVPTX/variadics-backend.ll (+10-10)
``````````diff
diff --git a/llvm/lib/Target/NVPTX/CMakeLists.txt b/llvm/lib/Target/NVPTX/CMakeLists.txt
index dfbda84534732..1cffde138eab7 100644
--- a/llvm/lib/Target/NVPTX/CMakeLists.txt
+++ b/llvm/lib/Target/NVPTX/CMakeLists.txt
@@ -16,6 +16,7 @@ set(NVPTXCodeGen_sources
NVPTXAtomicLower.cpp
NVPTXAsmPrinter.cpp
NVPTXAssignValidGlobalNames.cpp
+ NVPTXForwardParams.cpp
NVPTXFrameLowering.cpp
NVPTXGenericToNVVM.cpp
NVPTXISelDAGToDAG.cpp
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index ca915cd3f3732..62f51861ac55a 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -52,6 +52,7 @@ FunctionPass *createNVPTXLowerUnreachablePass(bool TrapUnreachable,
bool NoTrapAfterNoreturn);
MachineFunctionPass *createNVPTXPeephole();
MachineFunctionPass *createNVPTXProxyRegErasurePass();
+MachineFunctionPass *createNVPTXForwardParamsPass();
struct NVVMIntrRangePass : PassInfoMixin<NVVMIntrRangePass> {
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
diff --git a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
new file mode 100644
index 0000000000000..47d44b985363d
--- /dev/null
+++ b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
@@ -0,0 +1,169 @@
+//- NVPTXForwardParams.cpp - NVPTX Forward Device Params Removing Local Copy -//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// PTX supports 2 methods of accessing device function parameters:
+//
+// - "simple" case: If a parameter is only loaded, and all loads can address
+// the parameter via a constant offset, then the parameter may be loaded via
+// the ".param" address space. This case is not possible if the parameter
+// is stored to or has its address taken. This method is preferable when
+// possible. Ex:
+//
+// ld.param.u32 %r1, [foo_param_1];
+// ld.param.u32 %r2, [foo_param_1+4];
+//
+// - "move param" case: For more complex cases the address of the param may be
+// placed in a register via a "mov" instruction. This "mov" also implicitly
+// moves the param to the ".local" address space and allows for it to be
+// written to. This essentially defers the responsibility of the byval copy
+// to the PTX calling convention.
+//
+// mov.b64 %rd1, foo_param_0;
+// st.local.u32 [%rd1], 42;
+// add.u64 %rd3, %rd1, %rd2;
+// ld.local.u32 %r2, [%rd3];
+//
+// In NVPTXLowerArgs and SelectionDAG, we pessimistically assume that all
+// parameters will use the "move param" case and the local address space. This
+// pass is responsible for switching to the "simple" case when possible, as it
+// is more efficient.
+//
+// We do this by simply traversing uses of the param "mov" instructions and
+// trivially checking if they are all loads.
+//
+//===----------------------------------------------------------------------===//
+
+#include "NVPTX.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineOperand.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/TargetRegisterInfo.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace llvm;
+
+static bool traverseMoveUse(MachineInstr &U, const MachineRegisterInfo &MRI,
+ SmallVectorImpl<MachineInstr *> &RemoveList,
+ SmallVectorImpl<MachineInstr *> &LoadInsts) {
+ switch (U.getOpcode()) {
+ case NVPTX::LD_f32:
+ case NVPTX::LD_f64:
+ case NVPTX::LD_i16:
+ case NVPTX::LD_i32:
+ case NVPTX::LD_i64:
+ case NVPTX::LD_i8:
+ case NVPTX::LDV_f32_v2:
+ case NVPTX::LDV_f32_v4:
+ case NVPTX::LDV_f64_v2:
+ case NVPTX::LDV_f64_v4:
+ case NVPTX::LDV_i16_v2:
+ case NVPTX::LDV_i16_v4:
+ case NVPTX::LDV_i32_v2:
+ case NVPTX::LDV_i32_v4:
+ case NVPTX::LDV_i64_v2:
+ case NVPTX::LDV_i64_v4:
+ case NVPTX::LDV_i8_v2:
+ case NVPTX::LDV_i8_v4: {
+ LoadInsts.push_back(&U);
+ return true;
+ }
+ case NVPTX::cvta_local:
+ case NVPTX::cvta_local_64:
+ case NVPTX::cvta_to_local:
+ case NVPTX::cvta_to_local_64: {
+ for (auto &U2 : MRI.use_instructions(U.operands_begin()->getReg()))
+ if (!traverseMoveUse(U2, MRI, RemoveList, LoadInsts))
+ return false;
+
+ RemoveList.push_back(&U);
+ return true;
+ }
+ default:
+ return false;
+ }
+}
+
+static bool eliminateMove(MachineInstr &Mov, const MachineRegisterInfo &MRI,
+ SmallVectorImpl<MachineInstr *> &RemoveList) {
+ SmallVector<MachineInstr *, 16> MaybeRemoveList;
+ SmallVector<MachineInstr *, 16> LoadInsts;
+
+ for (auto &U : MRI.use_instructions(Mov.operands_begin()->getReg()))
+ if (!traverseMoveUse(U, MRI, MaybeRemoveList, LoadInsts))
+ return false;
+
+ RemoveList.append(MaybeRemoveList);
+ RemoveList.push_back(&Mov);
+
+ const MachineOperand *ParamSymbol = Mov.uses().begin();
+ assert(ParamSymbol->isSymbol());
+
+ constexpr unsigned LDInstBasePtrOpIdx = 6;
+ constexpr unsigned LDInstAddrSpaceOpIdx = 2;
+ for (auto *LI : LoadInsts) {
+ (LI->uses().begin() + LDInstBasePtrOpIdx)
+ ->ChangeToES(ParamSymbol->getSymbolName());
+ (LI->uses().begin() + LDInstAddrSpaceOpIdx)
+ ->ChangeToImmediate(NVPTX::AddressSpace::Param);
+ }
+ return true;
+}
+
+static bool forwardDeviceParams(MachineFunction &MF) {
+ const auto &MRI = MF.getRegInfo();
+
+ bool Changed = false;
+ SmallVector<MachineInstr *, 16> RemoveList;
+ for (auto &MI : make_early_inc_range(*MF.begin()))
+ if (MI.getOpcode() == NVPTX::MOV32_PARAM ||
+ MI.getOpcode() == NVPTX::MOV64_PARAM)
+ Changed |= eliminateMove(MI, MRI, RemoveList);
+
+ for (auto *MI : RemoveList)
+ MI->eraseFromParent();
+
+ return Changed;
+}
+
+/// ----------------------------------------------------------------------------
+/// Pass (Manager) Boilerplate
+/// ----------------------------------------------------------------------------
+
+namespace llvm {
+void initializeNVPTXForwardParamsPassPass(PassRegistry &);
+} // namespace llvm
+
+namespace {
+struct NVPTXForwardParamsPass : public MachineFunctionPass {
+ static char ID;
+ NVPTXForwardParamsPass() : MachineFunctionPass(ID) {
+ initializeNVPTXForwardParamsPassPass(*PassRegistry::getPassRegistry());
+ }
+
+ bool runOnMachineFunction(MachineFunction &MF) override;
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ MachineFunctionPass::getAnalysisUsage(AU);
+ }
+};
+} // namespace
+
+char NVPTXForwardParamsPass::ID = 0;
+
+INITIALIZE_PASS(NVPTXForwardParamsPass, "nvptx-forward-params",
+ "NVPTX Forward Params", false, false)
+
+bool NVPTXForwardParamsPass::runOnMachineFunction(MachineFunction &MF) {
+ return forwardDeviceParams(MF);
+}
+
+MachineFunctionPass *llvm::createNVPTXForwardParamsPass() {
+ return new NVPTXForwardParamsPass();
+}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 8a5cdd7412bf3..0461ed4712221 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -2197,11 +2197,11 @@ static SDValue selectBaseADDR(SDValue N, SelectionDAG *DAG) {
if (N.getOpcode() == NVPTXISD::Wrapper)
return N.getOperand(0);
- // addrspacecast(MoveParam(arg_symbol) to addrspace(PARAM)) -> arg_symbol
+ // addrspacecast(Wrapper(arg_symbol) to addrspace(PARAM)) -> arg_symbol
if (AddrSpaceCastSDNode *CastN = dyn_cast<AddrSpaceCastSDNode>(N))
if (CastN->getSrcAddressSpace() == ADDRESS_SPACE_GENERIC &&
CastN->getDestAddressSpace() == ADDRESS_SPACE_PARAM &&
- CastN->getOperand(0).getOpcode() == NVPTXISD::MoveParam)
+ CastN->getOperand(0).getOpcode() == NVPTXISD::Wrapper)
return selectBaseADDR(CastN->getOperand(0).getOperand(0), DAG);
if (auto *FIN = dyn_cast<FrameIndexSDNode>(N))
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 684098681d1ab..0ae6f9004b458 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3376,10 +3376,18 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
assert(ObjectVT == Ins[InsIdx].VT &&
"Ins type did not match function type");
SDValue Arg = getParamSymbol(DAG, i, PtrVT);
- SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
- if (p.getNode())
- p.getNode()->setIROrder(i + 1);
- InVals.push_back(p);
+
+ SDValue P;
+ if (isKernelFunction(*F)) {
+ P = DAG.getNode(NVPTXISD::Wrapper, dl, ObjectVT, Arg);
+ P.getNode()->setIROrder(i + 1);
+ } else {
+ P = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
+ P.getNode()->setIROrder(i + 1);
+ P = DAG.getAddrSpaceCast(dl, ObjectVT, P, ADDRESS_SPACE_LOCAL,
+ ADDRESS_SPACE_GENERIC);
+ }
+ InVals.push_back(P);
}
if (!OutChains.empty())
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 36a0a06bdb8aa..6edb0998760b8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -2324,7 +2324,7 @@ def SDTCallArgProfile : SDTypeProfile<0, 2, [SDTCisInt<0>]>;
def SDTCallArgMarkProfile : SDTypeProfile<0, 0, []>;
def SDTCallVoidProfile : SDTypeProfile<0, 1, []>;
def SDTCallValProfile : SDTypeProfile<1, 0, []>;
-def SDTMoveParamProfile : SDTypeProfile<1, 1, []>;
+def SDTMoveParamProfile : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisInt<1>]>;
def SDTStoreRetvalProfile : SDTypeProfile<0, 2, [SDTCisInt<0>]>;
def SDTStoreRetvalV2Profile : SDTypeProfile<0, 3, [SDTCisInt<0>]>;
def SDTStoreRetvalV4Profile : SDTypeProfile<0, 5, [SDTCisInt<0>]>;
@@ -2688,29 +2688,14 @@ def DeclareScalarRegInst :
".reg .b$size param$a;",
[(DeclareScalarParam (i32 imm:$a), (i32 imm:$size), (i32 1))]>;
-class MoveParamInst<ValueType T, NVPTXRegClass regclass, string asmstr> :
- NVPTXInst<(outs regclass:$dst), (ins regclass:$src),
- !strconcat("mov", asmstr, " \t$dst, $src;"),
- [(set T:$dst, (MoveParam T:$src))]>;
-
class MoveParamSymbolInst<NVPTXRegClass regclass, Operand srcty, ValueType vt,
string asmstr> :
NVPTXInst<(outs regclass:$dst), (ins srcty:$src),
!strconcat("mov", asmstr, " \t$dst, $src;"),
[(set vt:$dst, (MoveParam texternalsym:$src))]>;
-def MoveParamI64 : MoveParamInst<i64, Int64Regs, ".b64">;
-def MoveParamI32 : MoveParamInst<i32, Int32Regs, ".b32">;
-
-def MoveParamSymbolI64 : MoveParamSymbolInst<Int64Regs, i64imm, i64, ".b64">;
-def MoveParamSymbolI32 : MoveParamSymbolInst<Int32Regs, i32imm, i32, ".b32">;
-
-def MoveParamI16 :
- NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src),
- "cvt.u16.u32 \t$dst, $src;", // ??? Why cvt.u16.u32 ?
- [(set i16:$dst, (MoveParam i16:$src))]>;
-def MoveParamF64 : MoveParamInst<f64, Float64Regs, ".f64">;
-def MoveParamF32 : MoveParamInst<f32, Float32Regs, ".f32">;
+def MOV64_PARAM : MoveParamSymbolInst<Int64Regs, i64imm, i64, ".b64">;
+def MOV32_PARAM : MoveParamSymbolInst<Int32Regs, i32imm, i32, ".b32">;
class PseudoUseParamInst<NVPTXRegClass regclass, ValueType vt> :
NVPTXInst<(outs), (ins regclass:$src),
diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index c763b54c8dbfe..6cef245e2d98e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -153,6 +153,7 @@
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/NVPTXAddrSpace.h"
#include <numeric>
#include <queue>
@@ -373,19 +374,19 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
Type *StructType = Arg->getParamByValType();
const DataLayout &DL = Func->getDataLayout();
- uint64_t NewArgAlign =
- TLI->getFunctionParamOptimizedAlign(Func, StructType, DL).value();
- uint64_t CurArgAlign =
- Arg->getAttribute(Attribute::Alignment).getValueAsInt();
+ const Align NewArgAlign =
+ TLI->getFunctionParamOptimizedAlign(Func, StructType, DL);
+ const Align CurArgAlign = Arg->getParamAlign().valueOrOne();
if (CurArgAlign >= NewArgAlign)
return;
- LLVM_DEBUG(dbgs() << "Try to use alignment " << NewArgAlign << " instead of "
- << CurArgAlign << " for " << *Arg << '\n');
+ LLVM_DEBUG(dbgs() << "Try to use alignment " << NewArgAlign.value()
+ << " instead of " << CurArgAlign.value() << " for " << *Arg
+ << '\n');
auto NewAlignAttr =
- Attribute::get(Func->getContext(), Attribute::Alignment, NewArgAlign);
+ Attribute::getWithAlignment(Func->getContext(), NewArgAlign);
Arg->removeAttr(Attribute::Alignment);
Arg->addAttr(NewAlignAttr);
@@ -402,7 +403,6 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
SmallVector<Load> Loads;
std::queue<LoadContext> Worklist;
Worklist.push({ArgInParamAS, 0});
- bool IsGridConstant = isParamGridConstant(*Arg);
while (!Worklist.empty()) {
LoadContext Ctx = Worklist.front();
@@ -411,15 +411,9 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
for (User *CurUser : Ctx.InitialVal->users()) {
if (auto *I = dyn_cast<LoadInst>(CurUser)) {
Loads.push_back({I, Ctx.Offset});
- continue;
- }
-
- if (auto *I = dyn_cast<BitCastInst>(CurUser)) {
- Worklist.push({I, Ctx.Offset});
- continue;
- }
-
- if (auto *I = dyn_cast<GetElementPtrInst>(CurUser)) {
+ } else if (isa<BitCastInst>(CurUser) || isa<AddrSpaceCastInst>(CurUser)) {
+ Worklist.push({cast<Instruction>(CurUser), Ctx.Offset});
+ } else if (auto *I = dyn_cast<GetElementPtrInst>(CurUser)) {
APInt OffsetAccumulated =
APInt::getZero(DL.getIndexSizeInBits(ADDRESS_SPACE_PARAM));
@@ -431,26 +425,13 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
assert(Offset != OffsetLimit && "Expect Offset less than UINT64_MAX");
Worklist.push({I, Ctx.Offset + Offset});
- continue;
}
-
- if (isa<MemTransferInst>(CurUser))
- continue;
-
- // supported for grid_constant
- if (IsGridConstant &&
- (isa<CallInst>(CurUser) || isa<StoreInst>(CurUser) ||
- isa<PtrToIntInst>(CurUser)))
- continue;
-
- llvm_unreachable("All users must be one of: load, "
- "bitcast, getelementptr, call, store, ptrtoint");
}
}
for (Load &CurLoad : Loads) {
- Align NewLoadAlign(std::gcd(NewArgAlign, CurLoad.Offset));
- Align CurLoadAlign(CurLoad.Inst->getAlign());
+ Align NewLoadAlign(std::gcd(NewArgAlign.value(), CurLoad.Offset));
+ Align CurLoadAlign = CurLoad.Inst->getAlign();
CurLoad.Inst->setAlignment(std::max(NewLoadAlign, CurLoadAlign));
}
}
@@ -641,7 +622,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
copyByValParam(*Func, *Arg);
}
-void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
+static void markPointerAsAS(Value *Ptr, const unsigned AS) {
if (Ptr->getType()->getPointerAddressSpace() != ADDRESS_SPACE_GENERIC)
return;
@@ -658,7 +639,7 @@ void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
}
Instruction *PtrInGlobal = new AddrSpaceCastInst(
- Ptr, PointerType::get(Ptr->getContext(), ADDRESS_SPACE_GLOBAL),
+ Ptr, PointerType::get(Ptr->getContext(), AS),
Ptr->getName(), InsertPt);
Value *PtrInGeneric = new AddrSpaceCastInst(PtrInGlobal, Ptr->getType(),
Ptr->getName(), InsertPt);
@@ -667,6 +648,10 @@ void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
PtrInGlobal->setOperand(0, Ptr);
}
+void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
+ markPointerAsAS(Ptr, ADDRESS_SPACE_GLOBAL);
+}
+
// =============================================================================
// Main function for this pass.
// =============================================================================
@@ -724,9 +709,15 @@ bool NVPTXLowerArgs::runOnKernelFunction(const NVPTXTargetMachine &TM,
bool NVPTXLowerArgs::runOnDeviceFunction(const NVPTXTargetMachine &TM,
Function &F) {
LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n");
+
+ const auto *TLI =
+ cast<NVPTXTargetLowering>(TM.getSubtargetImpl()->getTargetLowering());
+
for (Argument &Arg : F.args())
- if (Arg.getType()->isPointerTy() && Arg.hasByValAttr())
- handleByValParam(TM, &Arg);
+ if (Arg.getType()->isPointerTy() && Arg.hasByValAttr()) {
+ markPointerAsAS(&Arg, ADDRESS_SPACE_LOCAL);
+ adjustByValArgAlignment(&Arg, &Arg, TLI);
+ }
return true;
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
index f2afa6fc20bfa..229fecf2d3b10 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
@@ -100,6 +100,7 @@ void initializeNVPTXLowerUnreachablePass(PassRegistry &);
void initializeNVPTXCtorDtorLoweringLegacyPass(PassRegistry &);
void initializeNVPTXLowerArgsPass(PassRegistry &);
void initializeNVPTXProxyRegErasurePass(PassRegistry &);
+void initializeNVPTXForwardParamsPassPass(PassRegistry &);
void initializeNVVMIntrRangePass(PassRegistry &);
void initializeNVVMReflectPass(PassRegistry &);
void initializeNVPTXAAWrapperPassPass(PassRegistry &);
@@ -127,6 +128,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXTarget() {
initializeNVPTXCtorDtorLoweringLegacyPass(PR);
initializeNVPTXLowerAggrCopiesPass(PR);
initializeNVPTXProxyRegErasurePass(PR);
+ initializeNVPTXForwardParamsPassPass(PR);
initializeNVPTXDAGToDAGISelLegacyPass(PR);
initializeNVPTXAAWrapperPassPass(PR);
initializeNVPTXExternalAAWrapperPass(PR);
@@ -429,6 +431,7 @@ bool NVPTXPassConfig::addInstSelector() {
}
void NVPTXPassConfig::addPreRegAlloc() {
+ addPass(createNVPTXForwardParamsPass());
// Remove Proxy Register pseudo instructions used to keep `callseq_end` alive.
addPass(createNVPTXProxyRegErasurePass());
}
diff --git a/llvm/test/CodeGen/NVPTX/forward-ld-param.ll b/llvm/test/CodeGen/NVPTX/forward-ld-param.ll
new file mode 100644
index 0000000000000..c4e56d197edc0
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/forward-ld-param.ll
@@ -0,0 +1,142 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s | FileCheck %s
+
+target triple = "nvptx64-nvidia-cuda"
+
+define i32 @test_ld_param_const(ptr byval(i32) %a) {
+; CHECK-LABEL: test_ld_param_const(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test_ld_param_const_param_0+4];
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %p2 = getelementptr i32, ptr %a, i32 1
+ %ld = load i32, ptr %p2
+ ret i32 %ld
+}
+
+define i32 @test_ld_param_non_const(ptr byval([10 x i32]) %a, i32 %b) {
+; CHECK-LABEL: test_ld_param_non_const(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<6>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: mov.b64 %rd1, test_ld_param_non_const_param_0;
+; CHECK-NEXT: cvta.local.u64 %rd2, %rd1;
+; CHECK-NEXT: cvta.to.local.u64 %rd3, %rd2;
+; CHECK-NEXT: ld.param.s32 %rd4, [test_ld_param_non_const_param_1];
+; CHECK-NEXT: add.s64 %rd5, %rd3, %rd4;
+; CHECK-NEXT: ld.local.u32 %r1, [%rd5];
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %p2 = getelementptr i8, ptr %a, i32 %b
+ %ld = load i32, ptr %p2
+ ret i32 %ld
+}
+
+declare void @escape(ptr)
+declare void @byval_user(ptr byval(i32))
+
+define void @test_ld_param_e...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/129188