[llvm] [NVPTX] Improve device function byval parameter lowering (PR #129188)
Alex MacLean via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 27 21:27:52 PST 2025
https://github.com/AlexMaclean created https://github.com/llvm/llvm-project/pull/129188
PTX supports 2 methods of accessing device function parameters:
- "simple" case: If a parameter is only loaded, and all loads can address the parameter via a constant offset, then the parameter may be loaded via the ".param" address space. This case is not possible if the parameter is stored to or has its address taken. This method is preferable when possible.
- "move param" case: For more complex cases, the address of the param may be placed in a register via a "mov" instruction. This "mov" also implicitly moves the param to the ".local" address space and allows it to be written to. This essentially defers the responsibility of the byval copy to the PTX calling convention.
The handling of these cases in the NVPTX backend for byval pointers has some major issues. We currently attempt to determine if a copy is necessary in NVPTXLowerArgs and either explicitly make an additional copy in the IR, or insert an "addrspacecast" to move the param into the param address space. Unfortunately, the criteria for determining which case is possible are not correct, leading to miscompilations (https://godbolt.org/z/Gq1fP7a3G). Further, the criteria for the "simple" case aren't enforceable in LLVM IR across other transformations and instruction selection, making deciding between the two cases in NVPTXLowerArgs brittle and buggy.
This patch aims to fix these issues and improve address-space-related optimization. In NVPTXLowerArgs, we conservatively assume that all parameters will use the "move param" case and the local address space. Responsibility for switching to the "simple" case is given to a new MachineIR pass, NVPTXForwardParams, which runs once it has become clear whether or not this is possible. This ensures that the correct address space is known for the "move param" case, allowing for optimization, while still using the "simple" case wherever possible.
>From e880f61eb8678b63edf2fd281dd104b7e981bb67 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Fri, 28 Feb 2025 05:10:39 +0000
Subject: [PATCH] [NVPTX] Improve byval device parameter lowering
---
llvm/lib/Target/NVPTX/CMakeLists.txt | 1 +
llvm/lib/Target/NVPTX/NVPTX.h | 1 +
llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp | 169 ++++++++++++++++++
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 4 +-
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 16 +-
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 21 +--
llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp | 63 +++----
llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp | 3 +
llvm/test/CodeGen/NVPTX/forward-ld-param.ll | 142 +++++++++++++++
llvm/test/CodeGen/NVPTX/i128-array.ll | 15 +-
.../CodeGen/NVPTX/lower-args-gridconstant.ll | 50 +++---
llvm/test/CodeGen/NVPTX/lower-args.ll | 23 +--
llvm/test/CodeGen/NVPTX/variadics-backend.ll | 20 +--
13 files changed, 411 insertions(+), 117 deletions(-)
create mode 100644 llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
create mode 100644 llvm/test/CodeGen/NVPTX/forward-ld-param.ll
diff --git a/llvm/lib/Target/NVPTX/CMakeLists.txt b/llvm/lib/Target/NVPTX/CMakeLists.txt
index dfbda84534732..1cffde138eab7 100644
--- a/llvm/lib/Target/NVPTX/CMakeLists.txt
+++ b/llvm/lib/Target/NVPTX/CMakeLists.txt
@@ -16,6 +16,7 @@ set(NVPTXCodeGen_sources
NVPTXAtomicLower.cpp
NVPTXAsmPrinter.cpp
NVPTXAssignValidGlobalNames.cpp
+ NVPTXForwardParams.cpp
NVPTXFrameLowering.cpp
NVPTXGenericToNVVM.cpp
NVPTXISelDAGToDAG.cpp
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index ca915cd3f3732..62f51861ac55a 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -52,6 +52,7 @@ FunctionPass *createNVPTXLowerUnreachablePass(bool TrapUnreachable,
bool NoTrapAfterNoreturn);
MachineFunctionPass *createNVPTXPeephole();
MachineFunctionPass *createNVPTXProxyRegErasurePass();
+MachineFunctionPass *createNVPTXForwardParamsPass();
struct NVVMIntrRangePass : PassInfoMixin<NVVMIntrRangePass> {
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
diff --git a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
new file mode 100644
index 0000000000000..47d44b985363d
--- /dev/null
+++ b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
@@ -0,0 +1,169 @@
+//- NVPTXForwardParams.cpp - NVPTX Forward Device Params Removing Local Copy -//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// PTX supports 2 methods of accessing device function parameters:
+//
+// - "simple" case: If a parameter is only loaded, and all loads can address
+// the parameter via a constant offset, then the parameter may be loaded via
+// the ".param" address space. This case is not possible if the parameter
+// is stored to or has its address taken. This method is preferable when
+// possible. Ex:
+//
+// ld.param.u32 %r1, [foo_param_1];
+// ld.param.u32 %r2, [foo_param_1+4];
+//
+// - "move param" case: For more complex cases the address of the param may be
+// placed in a register via a "mov" instruction. This "mov" also implicitly
+// moves the param to the ".local" address space and allows for it to be
+// written to. This essentially defers the responsibility of the byval copy
+// to the PTX calling convention.
+//
+// mov.b64 %rd1, foo_param_0;
+// st.local.u32 [%rd1], 42;
+// add.u64 %rd3, %rd1, %rd2;
+// ld.local.u32 %r2, [%rd3];
+//
+// In NVPTXLowerArgs and SelectionDAG, we pessimistically assume that all
+// parameters will use the "move param" case and the local address space. This
+// pass is responsible for switching to the "simple" case when possible, as it
+// is more efficient.
+//
+// We do this by simply traversing uses of the param "mov" instructions and
+// trivially checking if they are all loads.
+//
+//===----------------------------------------------------------------------===//
+
+#include "NVPTX.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineOperand.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/TargetRegisterInfo.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace llvm;
+
+static bool traverseMoveUse(MachineInstr &U, const MachineRegisterInfo &MRI,
+ SmallVectorImpl<MachineInstr *> &RemoveList,
+ SmallVectorImpl<MachineInstr *> &LoadInsts) {
+ switch (U.getOpcode()) {
+ case NVPTX::LD_f32:
+ case NVPTX::LD_f64:
+ case NVPTX::LD_i16:
+ case NVPTX::LD_i32:
+ case NVPTX::LD_i64:
+ case NVPTX::LD_i8:
+ case NVPTX::LDV_f32_v2:
+ case NVPTX::LDV_f32_v4:
+ case NVPTX::LDV_f64_v2:
+ case NVPTX::LDV_f64_v4:
+ case NVPTX::LDV_i16_v2:
+ case NVPTX::LDV_i16_v4:
+ case NVPTX::LDV_i32_v2:
+ case NVPTX::LDV_i32_v4:
+ case NVPTX::LDV_i64_v2:
+ case NVPTX::LDV_i64_v4:
+ case NVPTX::LDV_i8_v2:
+ case NVPTX::LDV_i8_v4: {
+ LoadInsts.push_back(&U);
+ return true;
+ }
+ case NVPTX::cvta_local:
+ case NVPTX::cvta_local_64:
+ case NVPTX::cvta_to_local:
+ case NVPTX::cvta_to_local_64: {
+ for (auto &U2 : MRI.use_instructions(U.operands_begin()->getReg()))
+ if (!traverseMoveUse(U2, MRI, RemoveList, LoadInsts))
+ return false;
+
+ RemoveList.push_back(&U);
+ return true;
+ }
+ default:
+ return false;
+ }
+}
+
+static bool eliminateMove(MachineInstr &Mov, const MachineRegisterInfo &MRI,
+ SmallVectorImpl<MachineInstr *> &RemoveList) {
+ SmallVector<MachineInstr *, 16> MaybeRemoveList;
+ SmallVector<MachineInstr *, 16> LoadInsts;
+
+ for (auto &U : MRI.use_instructions(Mov.operands_begin()->getReg()))
+ if (!traverseMoveUse(U, MRI, MaybeRemoveList, LoadInsts))
+ return false;
+
+ RemoveList.append(MaybeRemoveList);
+ RemoveList.push_back(&Mov);
+
+ const MachineOperand *ParamSymbol = Mov.uses().begin();
+ assert(ParamSymbol->isSymbol());
+
+ constexpr unsigned LDInstBasePtrOpIdx = 6;
+ constexpr unsigned LDInstAddrSpaceOpIdx = 2;
+ for (auto *LI : LoadInsts) {
+ (LI->uses().begin() + LDInstBasePtrOpIdx)
+ ->ChangeToES(ParamSymbol->getSymbolName());
+ (LI->uses().begin() + LDInstAddrSpaceOpIdx)
+ ->ChangeToImmediate(NVPTX::AddressSpace::Param);
+ }
+ return true;
+}
+
+static bool forwardDeviceParams(MachineFunction &MF) {
+ const auto &MRI = MF.getRegInfo();
+
+ bool Changed = false;
+ SmallVector<MachineInstr *, 16> RemoveList;
+ for (auto &MI : make_early_inc_range(*MF.begin()))
+ if (MI.getOpcode() == NVPTX::MOV32_PARAM ||
+ MI.getOpcode() == NVPTX::MOV64_PARAM)
+ Changed |= eliminateMove(MI, MRI, RemoveList);
+
+ for (auto *MI : RemoveList)
+ MI->eraseFromParent();
+
+ return Changed;
+}
+
+/// ----------------------------------------------------------------------------
+/// Pass (Manager) Boilerplate
+/// ----------------------------------------------------------------------------
+
+namespace llvm {
+void initializeNVPTXForwardParamsPassPass(PassRegistry &);
+} // namespace llvm
+
+namespace {
+struct NVPTXForwardParamsPass : public MachineFunctionPass {
+ static char ID;
+ NVPTXForwardParamsPass() : MachineFunctionPass(ID) {
+ initializeNVPTXForwardParamsPassPass(*PassRegistry::getPassRegistry());
+ }
+
+ bool runOnMachineFunction(MachineFunction &MF) override;
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ MachineFunctionPass::getAnalysisUsage(AU);
+ }
+};
+} // namespace
+
+char NVPTXForwardParamsPass::ID = 0;
+
+INITIALIZE_PASS(NVPTXForwardParamsPass, "nvptx-forward-params",
+ "NVPTX Forward Params", false, false)
+
+bool NVPTXForwardParamsPass::runOnMachineFunction(MachineFunction &MF) {
+ return forwardDeviceParams(MF);
+}
+
+MachineFunctionPass *llvm::createNVPTXForwardParamsPass() {
+ return new NVPTXForwardParamsPass();
+}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 8a5cdd7412bf3..0461ed4712221 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -2197,11 +2197,11 @@ static SDValue selectBaseADDR(SDValue N, SelectionDAG *DAG) {
if (N.getOpcode() == NVPTXISD::Wrapper)
return N.getOperand(0);
- // addrspacecast(MoveParam(arg_symbol) to addrspace(PARAM)) -> arg_symbol
+ // addrspacecast(Wrapper(arg_symbol) to addrspace(PARAM)) -> arg_symbol
if (AddrSpaceCastSDNode *CastN = dyn_cast<AddrSpaceCastSDNode>(N))
if (CastN->getSrcAddressSpace() == ADDRESS_SPACE_GENERIC &&
CastN->getDestAddressSpace() == ADDRESS_SPACE_PARAM &&
- CastN->getOperand(0).getOpcode() == NVPTXISD::MoveParam)
+ CastN->getOperand(0).getOpcode() == NVPTXISD::Wrapper)
return selectBaseADDR(CastN->getOperand(0).getOperand(0), DAG);
if (auto *FIN = dyn_cast<FrameIndexSDNode>(N))
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 684098681d1ab..0ae6f9004b458 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3376,10 +3376,18 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
assert(ObjectVT == Ins[InsIdx].VT &&
"Ins type did not match function type");
SDValue Arg = getParamSymbol(DAG, i, PtrVT);
- SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
- if (p.getNode())
- p.getNode()->setIROrder(i + 1);
- InVals.push_back(p);
+
+ SDValue P;
+ if (isKernelFunction(*F)) {
+ P = DAG.getNode(NVPTXISD::Wrapper, dl, ObjectVT, Arg);
+ P.getNode()->setIROrder(i + 1);
+ } else {
+ P = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
+ P.getNode()->setIROrder(i + 1);
+ P = DAG.getAddrSpaceCast(dl, ObjectVT, P, ADDRESS_SPACE_LOCAL,
+ ADDRESS_SPACE_GENERIC);
+ }
+ InVals.push_back(P);
}
if (!OutChains.empty())
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 36a0a06bdb8aa..6edb0998760b8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -2324,7 +2324,7 @@ def SDTCallArgProfile : SDTypeProfile<0, 2, [SDTCisInt<0>]>;
def SDTCallArgMarkProfile : SDTypeProfile<0, 0, []>;
def SDTCallVoidProfile : SDTypeProfile<0, 1, []>;
def SDTCallValProfile : SDTypeProfile<1, 0, []>;
-def SDTMoveParamProfile : SDTypeProfile<1, 1, []>;
+def SDTMoveParamProfile : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisInt<1>]>;
def SDTStoreRetvalProfile : SDTypeProfile<0, 2, [SDTCisInt<0>]>;
def SDTStoreRetvalV2Profile : SDTypeProfile<0, 3, [SDTCisInt<0>]>;
def SDTStoreRetvalV4Profile : SDTypeProfile<0, 5, [SDTCisInt<0>]>;
@@ -2688,29 +2688,14 @@ def DeclareScalarRegInst :
".reg .b$size param$a;",
[(DeclareScalarParam (i32 imm:$a), (i32 imm:$size), (i32 1))]>;
-class MoveParamInst<ValueType T, NVPTXRegClass regclass, string asmstr> :
- NVPTXInst<(outs regclass:$dst), (ins regclass:$src),
- !strconcat("mov", asmstr, " \t$dst, $src;"),
- [(set T:$dst, (MoveParam T:$src))]>;
-
class MoveParamSymbolInst<NVPTXRegClass regclass, Operand srcty, ValueType vt,
string asmstr> :
NVPTXInst<(outs regclass:$dst), (ins srcty:$src),
!strconcat("mov", asmstr, " \t$dst, $src;"),
[(set vt:$dst, (MoveParam texternalsym:$src))]>;
-def MoveParamI64 : MoveParamInst<i64, Int64Regs, ".b64">;
-def MoveParamI32 : MoveParamInst<i32, Int32Regs, ".b32">;
-
-def MoveParamSymbolI64 : MoveParamSymbolInst<Int64Regs, i64imm, i64, ".b64">;
-def MoveParamSymbolI32 : MoveParamSymbolInst<Int32Regs, i32imm, i32, ".b32">;
-
-def MoveParamI16 :
- NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src),
- "cvt.u16.u32 \t$dst, $src;", // ??? Why cvt.u16.u32 ?
- [(set i16:$dst, (MoveParam i16:$src))]>;
-def MoveParamF64 : MoveParamInst<f64, Float64Regs, ".f64">;
-def MoveParamF32 : MoveParamInst<f32, Float32Regs, ".f32">;
+def MOV64_PARAM : MoveParamSymbolInst<Int64Regs, i64imm, i64, ".b64">;
+def MOV32_PARAM : MoveParamSymbolInst<Int32Regs, i32imm, i32, ".b32">;
class PseudoUseParamInst<NVPTXRegClass regclass, ValueType vt> :
NVPTXInst<(outs), (ins regclass:$src),
diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index c763b54c8dbfe..6cef245e2d98e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -153,6 +153,7 @@
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/NVPTXAddrSpace.h"
#include <numeric>
#include <queue>
@@ -373,19 +374,19 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
Type *StructType = Arg->getParamByValType();
const DataLayout &DL = Func->getDataLayout();
- uint64_t NewArgAlign =
- TLI->getFunctionParamOptimizedAlign(Func, StructType, DL).value();
- uint64_t CurArgAlign =
- Arg->getAttribute(Attribute::Alignment).getValueAsInt();
+ const Align NewArgAlign =
+ TLI->getFunctionParamOptimizedAlign(Func, StructType, DL);
+ const Align CurArgAlign = Arg->getParamAlign().valueOrOne();
if (CurArgAlign >= NewArgAlign)
return;
- LLVM_DEBUG(dbgs() << "Try to use alignment " << NewArgAlign << " instead of "
- << CurArgAlign << " for " << *Arg << '\n');
+ LLVM_DEBUG(dbgs() << "Try to use alignment " << NewArgAlign.value()
+ << " instead of " << CurArgAlign.value() << " for " << *Arg
+ << '\n');
auto NewAlignAttr =
- Attribute::get(Func->getContext(), Attribute::Alignment, NewArgAlign);
+ Attribute::getWithAlignment(Func->getContext(), NewArgAlign);
Arg->removeAttr(Attribute::Alignment);
Arg->addAttr(NewAlignAttr);
@@ -402,7 +403,6 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
SmallVector<Load> Loads;
std::queue<LoadContext> Worklist;
Worklist.push({ArgInParamAS, 0});
- bool IsGridConstant = isParamGridConstant(*Arg);
while (!Worklist.empty()) {
LoadContext Ctx = Worklist.front();
@@ -411,15 +411,9 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
for (User *CurUser : Ctx.InitialVal->users()) {
if (auto *I = dyn_cast<LoadInst>(CurUser)) {
Loads.push_back({I, Ctx.Offset});
- continue;
- }
-
- if (auto *I = dyn_cast<BitCastInst>(CurUser)) {
- Worklist.push({I, Ctx.Offset});
- continue;
- }
-
- if (auto *I = dyn_cast<GetElementPtrInst>(CurUser)) {
+ } else if (isa<BitCastInst>(CurUser) || isa<AddrSpaceCastInst>(CurUser)) {
+ Worklist.push({cast<Instruction>(CurUser), Ctx.Offset});
+ } else if (auto *I = dyn_cast<GetElementPtrInst>(CurUser)) {
APInt OffsetAccumulated =
APInt::getZero(DL.getIndexSizeInBits(ADDRESS_SPACE_PARAM));
@@ -431,26 +425,13 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
assert(Offset != OffsetLimit && "Expect Offset less than UINT64_MAX");
Worklist.push({I, Ctx.Offset + Offset});
- continue;
}
-
- if (isa<MemTransferInst>(CurUser))
- continue;
-
- // supported for grid_constant
- if (IsGridConstant &&
- (isa<CallInst>(CurUser) || isa<StoreInst>(CurUser) ||
- isa<PtrToIntInst>(CurUser)))
- continue;
-
- llvm_unreachable("All users must be one of: load, "
- "bitcast, getelementptr, call, store, ptrtoint");
}
}
for (Load &CurLoad : Loads) {
- Align NewLoadAlign(std::gcd(NewArgAlign, CurLoad.Offset));
- Align CurLoadAlign(CurLoad.Inst->getAlign());
+ Align NewLoadAlign(std::gcd(NewArgAlign.value(), CurLoad.Offset));
+ Align CurLoadAlign = CurLoad.Inst->getAlign();
CurLoad.Inst->setAlignment(std::max(NewLoadAlign, CurLoadAlign));
}
}
@@ -641,7 +622,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
copyByValParam(*Func, *Arg);
}
-void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
+static void markPointerAsAS(Value *Ptr, const unsigned AS) {
if (Ptr->getType()->getPointerAddressSpace() != ADDRESS_SPACE_GENERIC)
return;
@@ -658,7 +639,7 @@ void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
}
Instruction *PtrInGlobal = new AddrSpaceCastInst(
- Ptr, PointerType::get(Ptr->getContext(), ADDRESS_SPACE_GLOBAL),
+ Ptr, PointerType::get(Ptr->getContext(), AS),
Ptr->getName(), InsertPt);
Value *PtrInGeneric = new AddrSpaceCastInst(PtrInGlobal, Ptr->getType(),
Ptr->getName(), InsertPt);
@@ -667,6 +648,10 @@ void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
PtrInGlobal->setOperand(0, Ptr);
}
+void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
+ markPointerAsAS(Ptr, ADDRESS_SPACE_GLOBAL);
+}
+
// =============================================================================
// Main function for this pass.
// =============================================================================
@@ -724,9 +709,15 @@ bool NVPTXLowerArgs::runOnKernelFunction(const NVPTXTargetMachine &TM,
bool NVPTXLowerArgs::runOnDeviceFunction(const NVPTXTargetMachine &TM,
Function &F) {
LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n");
+
+ const auto *TLI =
+ cast<NVPTXTargetLowering>(TM.getSubtargetImpl()->getTargetLowering());
+
for (Argument &Arg : F.args())
- if (Arg.getType()->isPointerTy() && Arg.hasByValAttr())
- handleByValParam(TM, &Arg);
+ if (Arg.getType()->isPointerTy() && Arg.hasByValAttr()) {
+ markPointerAsAS(&Arg, ADDRESS_SPACE_LOCAL);
+ adjustByValArgAlignment(&Arg, &Arg, TLI);
+ }
return true;
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
index f2afa6fc20bfa..229fecf2d3b10 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
@@ -100,6 +100,7 @@ void initializeNVPTXLowerUnreachablePass(PassRegistry &);
void initializeNVPTXCtorDtorLoweringLegacyPass(PassRegistry &);
void initializeNVPTXLowerArgsPass(PassRegistry &);
void initializeNVPTXProxyRegErasurePass(PassRegistry &);
+void initializeNVPTXForwardParamsPassPass(PassRegistry &);
void initializeNVVMIntrRangePass(PassRegistry &);
void initializeNVVMReflectPass(PassRegistry &);
void initializeNVPTXAAWrapperPassPass(PassRegistry &);
@@ -127,6 +128,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXTarget() {
initializeNVPTXCtorDtorLoweringLegacyPass(PR);
initializeNVPTXLowerAggrCopiesPass(PR);
initializeNVPTXProxyRegErasurePass(PR);
+ initializeNVPTXForwardParamsPassPass(PR);
initializeNVPTXDAGToDAGISelLegacyPass(PR);
initializeNVPTXAAWrapperPassPass(PR);
initializeNVPTXExternalAAWrapperPass(PR);
@@ -429,6 +431,7 @@ bool NVPTXPassConfig::addInstSelector() {
}
void NVPTXPassConfig::addPreRegAlloc() {
+ addPass(createNVPTXForwardParamsPass());
// Remove Proxy Register pseudo instructions used to keep `callseq_end` alive.
addPass(createNVPTXProxyRegErasurePass());
}
diff --git a/llvm/test/CodeGen/NVPTX/forward-ld-param.ll b/llvm/test/CodeGen/NVPTX/forward-ld-param.ll
new file mode 100644
index 0000000000000..c4e56d197edc0
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/forward-ld-param.ll
@@ -0,0 +1,142 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s | FileCheck %s
+
+target triple = "nvptx64-nvidia-cuda"
+
+define i32 @test_ld_param_const(ptr byval(i32) %a) {
+; CHECK-LABEL: test_ld_param_const(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test_ld_param_const_param_0+4];
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %p2 = getelementptr i32, ptr %a, i32 1
+ %ld = load i32, ptr %p2
+ ret i32 %ld
+}
+
+define i32 @test_ld_param_non_const(ptr byval([10 x i32]) %a, i32 %b) {
+; CHECK-LABEL: test_ld_param_non_const(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<6>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: mov.b64 %rd1, test_ld_param_non_const_param_0;
+; CHECK-NEXT: cvta.local.u64 %rd2, %rd1;
+; CHECK-NEXT: cvta.to.local.u64 %rd3, %rd2;
+; CHECK-NEXT: ld.param.s32 %rd4, [test_ld_param_non_const_param_1];
+; CHECK-NEXT: add.s64 %rd5, %rd3, %rd4;
+; CHECK-NEXT: ld.local.u32 %r1, [%rd5];
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %p2 = getelementptr i8, ptr %a, i32 %b
+ %ld = load i32, ptr %p2
+ ret i32 %ld
+}
+
+declare void @escape(ptr)
+declare void @byval_user(ptr byval(i32))
+
+define void @test_ld_param_escaping(ptr byval(i32) %a) {
+; CHECK-LABEL: test_ld_param_escaping(
+; CHECK: {
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: mov.b64 %rd1, test_ld_param_escaping_param_0;
+; CHECK-NEXT: cvta.local.u64 %rd2, %rd1;
+; CHECK-NEXT: { // callseq 0, 0
+; CHECK-NEXT: .param .b64 param0;
+; CHECK-NEXT: st.param.b64 [param0], %rd2;
+; CHECK-NEXT: call.uni
+; CHECK-NEXT: escape,
+; CHECK-NEXT: (
+; CHECK-NEXT: param0
+; CHECK-NEXT: );
+; CHECK-NEXT: } // callseq 0
+; CHECK-NEXT: ret;
+ call void @escape(ptr %a)
+ ret void
+}
+
+define void @test_ld_param_byval(ptr byval(i32) %a) {
+; CHECK-LABEL: test_ld_param_byval(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test_ld_param_byval_param_0];
+; CHECK-NEXT: { // callseq 1, 0
+; CHECK-NEXT: .param .align 4 .b8 param0[4];
+; CHECK-NEXT: st.param.b32 [param0], %r1;
+; CHECK-NEXT: call.uni
+; CHECK-NEXT: byval_user,
+; CHECK-NEXT: (
+; CHECK-NEXT: param0
+; CHECK-NEXT: );
+; CHECK-NEXT: } // callseq 1
+; CHECK-NEXT: ret;
+ call void @byval_user(ptr %a)
+ ret void
+}
+
+define i32 @test_modify_param(ptr byval([10 x i32]) %a, i32 %b, i32 %c ) {
+; CHECK-LABEL: test_modify_param(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-NEXT: .reg .b64 %rd<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: mov.b64 %rd1, test_modify_param_param_0;
+; CHECK-NEXT: cvta.local.u64 %rd2, %rd1;
+; CHECK-NEXT: cvta.to.local.u64 %rd3, %rd2;
+; CHECK-NEXT: ld.param.u32 %r1, [test_modify_param_param_1];
+; CHECK-NEXT: ld.param.u32 %r2, [test_modify_param_param_2];
+; CHECK-NEXT: st.local.u32 [%rd3+2], %r1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT: ret;
+ %p2 = getelementptr i8, ptr %a, i32 2
+ store volatile i32 %b, ptr %p2
+ ret i32 %c
+}
+
+define i32 @test_multi_block(ptr byval([10 x i32]) %a, i1 %p) {
+; CHECK-LABEL: test_multi_block(
+; CHECK: {
+; CHECK-NEXT: .reg .pred %p<3>;
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b64 %rd<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u8 %rs1, [test_multi_block_param_1];
+; CHECK-NEXT: and.b16 %rs2, %rs1, 1;
+; CHECK-NEXT: setp.eq.b16 %p1, %rs2, 1;
+; CHECK-NEXT: not.pred %p2, %p1;
+; CHECK-NEXT: @%p2 bra $L__BB5_2;
+; CHECK-NEXT: // %bb.1: // %if
+; CHECK-NEXT: ld.param.u32 %r4, [test_multi_block_param_0+4];
+; CHECK-NEXT: bra.uni $L__BB5_3;
+; CHECK-NEXT: $L__BB5_2: // %else
+; CHECK-NEXT: ld.param.u32 %r4, [test_multi_block_param_0+8];
+; CHECK-NEXT: $L__BB5_3: // %end
+; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
+; CHECK-NEXT: ret;
+ br i1 %p, label %if, label %else
+if:
+ %p2 = getelementptr i8, ptr %a, i32 4
+ %v2 = load i32, ptr %p2
+ br label %end
+else:
+ %p3 = getelementptr i8, ptr %a, i32 8
+ %v3 = load i32, ptr %p3
+ br label %end
+end:
+ %v = phi i32 [ %v2, %if ], [ %v3, %else ]
+ ret i32 %v
+}
diff --git a/llvm/test/CodeGen/NVPTX/i128-array.ll b/llvm/test/CodeGen/NVPTX/i128-array.ll
index 348df8dcc7373..baa18880de840 100644
--- a/llvm/test/CodeGen/NVPTX/i128-array.ll
+++ b/llvm/test/CodeGen/NVPTX/i128-array.ll
@@ -27,16 +27,15 @@ define [2 x i128] @foo(i64 %a, i32 %b) {
define [2 x i128] @foo2(ptr byval([2 x i128]) %a) {
; CHECK-LABEL: foo2(
; CHECK: {
-; CHECK-NEXT: .reg .b64 %rd<6>;
+; CHECK-NEXT: .reg .b64 %rd<9>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: mov.b64 %rd1, foo2_param_0;
-; CHECK-NEXT: ld.param.u64 %rd2, [foo2_param_0+8];
-; CHECK-NEXT: ld.param.u64 %rd3, [foo2_param_0];
-; CHECK-NEXT: ld.param.u64 %rd4, [foo2_param_0+24];
-; CHECK-NEXT: ld.param.u64 %rd5, [foo2_param_0+16];
-; CHECK-NEXT: st.param.v2.b64 [func_retval0], {%rd3, %rd2};
-; CHECK-NEXT: st.param.v2.b64 [func_retval0+16], {%rd5, %rd4};
+; CHECK-NEXT: ld.param.u64 %rd5, [foo2_param_0+8];
+; CHECK-NEXT: ld.param.u64 %rd6, [foo2_param_0];
+; CHECK-NEXT: ld.param.u64 %rd7, [foo2_param_0+24];
+; CHECK-NEXT: ld.param.u64 %rd8, [foo2_param_0+16];
+; CHECK-NEXT: st.param.v2.b64 [func_retval0], {%rd6, %rd5};
+; CHECK-NEXT: st.param.v2.b64 [func_retval0+16], {%rd8, %rd7};
; CHECK-NEXT: ret;
%ptr0 = getelementptr [2 x i128], ptr %a, i64 0, i32 0
%1 = load i128, i128* %ptr0
diff --git a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
index fe15be5663be1..90f9306d036cd 100644
--- a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
@@ -12,9 +12,8 @@ define dso_local noundef i32 @non_kernel_function(ptr nocapture noundef readonly
; OPT-LABEL: define dso_local noundef i32 @non_kernel_function(
; OPT-SAME: ptr noundef readonly byval([[STRUCT_UINT4:%.*]]) align 16 captures(none) [[A:%.*]], i1 noundef zeroext [[B:%.*]], i32 noundef [[C:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
; OPT-NEXT: [[ENTRY:.*:]]
-; OPT-NEXT: [[A1:%.*]] = alloca [[STRUCT_UINT4]], align 16
-; OPT-NEXT: [[A2:%.*]] = addrspacecast ptr [[A]] to ptr addrspace(101)
-; OPT-NEXT: call void @llvm.memcpy.p0.p101.i64(ptr align 16 [[A1]], ptr addrspace(101) align 16 [[A2]], i64 16, i1 false)
+; OPT-NEXT: [[A2:%.*]] = addrspacecast ptr [[A]] to ptr addrspace(5)
+; OPT-NEXT: [[A1:%.*]] = addrspacecast ptr addrspace(5) [[A2]] to ptr
; OPT-NEXT: [[A_:%.*]] = select i1 [[B]], ptr [[A1]], ptr addrspacecast (ptr addrspace(1) @gi to ptr)
; OPT-NEXT: [[IDX_EXT:%.*]] = sext i32 [[C]] to i64
; OPT-NEXT: [[ADD_PTR:%.*]] = getelementptr inbounds i8, ptr [[A_]], i64 [[IDX_EXT]]
@@ -23,38 +22,29 @@ define dso_local noundef i32 @non_kernel_function(ptr nocapture noundef readonly
;
; PTX-LABEL: non_kernel_function(
; PTX: {
-; PTX-NEXT: .local .align 16 .b8 __local_depot0[16];
-; PTX-NEXT: .reg .b64 %SP;
-; PTX-NEXT: .reg .b64 %SPL;
; PTX-NEXT: .reg .pred %p<2>;
; PTX-NEXT: .reg .b16 %rs<3>;
; PTX-NEXT: .reg .b32 %r<11>;
-; PTX-NEXT: .reg .b64 %rd<10>;
+; PTX-NEXT: .reg .b64 %rd<8>;
; PTX-EMPTY:
; PTX-NEXT: // %bb.0: // %entry
-; PTX-NEXT: mov.u64 %SPL, __local_depot0;
-; PTX-NEXT: cvta.local.u64 %SP, %SPL;
+; PTX-NEXT: mov.b64 %rd1, non_kernel_function_param_0;
+; PTX-NEXT: cvta.local.u64 %rd2, %rd1;
; PTX-NEXT: ld.param.u8 %rs1, [non_kernel_function_param_1];
; PTX-NEXT: and.b16 %rs2, %rs1, 1;
; PTX-NEXT: setp.eq.b16 %p1, %rs2, 1;
-; PTX-NEXT: add.u64 %rd1, %SP, 0;
-; PTX-NEXT: add.u64 %rd2, %SPL, 0;
-; PTX-NEXT: ld.param.s32 %rd3, [non_kernel_function_param_2];
-; PTX-NEXT: ld.param.u64 %rd4, [non_kernel_function_param_0+8];
-; PTX-NEXT: st.local.u64 [%rd2+8], %rd4;
-; PTX-NEXT: ld.param.u64 %rd5, [non_kernel_function_param_0];
-; PTX-NEXT: st.local.u64 [%rd2], %rd5;
-; PTX-NEXT: mov.u64 %rd6, gi;
-; PTX-NEXT: cvta.global.u64 %rd7, %rd6;
-; PTX-NEXT: selp.b64 %rd8, %rd1, %rd7, %p1;
-; PTX-NEXT: add.s64 %rd9, %rd8, %rd3;
-; PTX-NEXT: ld.u8 %r1, [%rd9];
-; PTX-NEXT: ld.u8 %r2, [%rd9+1];
+; PTX-NEXT: mov.u64 %rd3, gi;
+; PTX-NEXT: cvta.global.u64 %rd4, %rd3;
+; PTX-NEXT: selp.b64 %rd5, %rd2, %rd4, %p1;
+; PTX-NEXT: ld.param.s32 %rd6, [non_kernel_function_param_2];
+; PTX-NEXT: add.s64 %rd7, %rd5, %rd6;
+; PTX-NEXT: ld.u8 %r1, [%rd7];
+; PTX-NEXT: ld.u8 %r2, [%rd7+1];
; PTX-NEXT: shl.b32 %r3, %r2, 8;
; PTX-NEXT: or.b32 %r4, %r3, %r1;
-; PTX-NEXT: ld.u8 %r5, [%rd9+2];
+; PTX-NEXT: ld.u8 %r5, [%rd7+2];
; PTX-NEXT: shl.b32 %r6, %r5, 16;
-; PTX-NEXT: ld.u8 %r7, [%rd9+3];
+; PTX-NEXT: ld.u8 %r7, [%rd7+3];
; PTX-NEXT: shl.b32 %r8, %r7, 24;
; PTX-NEXT: or.b32 %r9, %r8, %r6;
; PTX-NEXT: or.b32 %r10, %r9, %r4;
@@ -91,6 +81,7 @@ define ptx_kernel void @grid_const_int(ptr byval(i32) align 4 %input1, i32 %inpu
; OPT-NEXT: [[ADD:%.*]] = add i32 [[TMP]], [[INPUT2]]
; OPT-NEXT: store i32 [[ADD]], ptr [[OUT3]], align 4
; OPT-NEXT: ret void
+;
%tmp = load i32, ptr %input1, align 4
%add = add i32 %tmp, %input2
store i32 %add, ptr %out
@@ -125,6 +116,7 @@ define ptx_kernel void @grid_const_struct(ptr byval(%struct.s) align 4 %input, p
; OPT-NEXT: [[ADD:%.*]] = add i32 [[TMP1]], [[TMP2]]
; OPT-NEXT: store i32 [[ADD]], ptr [[OUT5]], align 4
; OPT-NEXT: ret void
+;
%gep1 = getelementptr inbounds %struct.s, ptr %input, i32 0, i32 0
%gep2 = getelementptr inbounds %struct.s, ptr %input, i32 0, i32 1
%int1 = load i32, ptr %gep1
@@ -165,6 +157,7 @@ define ptx_kernel void @grid_const_escape(ptr byval(%struct.s) align 4 %input) {
; OPT-NEXT: [[INPUT_PARAM_GEN:%.*]] = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) [[INPUT_PARAM]])
; OPT-NEXT: [[CALL:%.*]] = call i32 @escape(ptr [[INPUT_PARAM_GEN]])
; OPT-NEXT: ret void
+;
%call = call i32 @escape(ptr %input)
ret void
}
@@ -222,6 +215,7 @@ define ptx_kernel void @multiple_grid_const_escape(ptr byval(%struct.s) align 4
; OPT-NEXT: store i32 [[A]], ptr [[A_ADDR]], align 4
; OPT-NEXT: [[CALL:%.*]] = call i32 @escape3(ptr [[INPUT_PARAM_GEN]], ptr [[A_ADDR]], ptr [[B_PARAM_GEN]])
; OPT-NEXT: ret void
+;
%a.addr = alloca i32, align 4
store i32 %a, ptr %a.addr, align 4
%call = call i32 @escape3(ptr %input, ptr %a.addr, ptr %b)
@@ -249,6 +243,7 @@ define ptx_kernel void @grid_const_memory_escape(ptr byval(%struct.s) align 4 %i
; OPT-NEXT: [[INPUT1:%.*]] = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) [[INPUT_PARAM]])
; OPT-NEXT: store ptr [[INPUT1]], ptr [[ADDR5]], align 8
; OPT-NEXT: ret void
+;
store ptr %input, ptr %addr, align 8
ret void
}
@@ -282,6 +277,7 @@ define ptx_kernel void @grid_const_inlineasm_escape(ptr byval(%struct.s) align 4
; OPT-NEXT: [[TMP2:%.*]] = call i64 asm "add.s64 $0, $1, $2
; OPT-NEXT: store i64 [[TMP2]], ptr [[RESULT5]], align 8
; OPT-NEXT: ret void
+;
%tmpptr1 = getelementptr inbounds %struct.s, ptr %input, i32 0, i32 0
%tmpptr2 = getelementptr inbounds %struct.s, ptr %input, i32 0, i32 1
%1 = call i64 asm "add.s64 $0, $1, $2;", "=l,l,l"(ptr %tmpptr1, ptr %tmpptr2) #1
@@ -330,6 +326,7 @@ define ptx_kernel void @grid_const_partial_escape(ptr byval(i32) %input, ptr %ou
; OPT-NEXT: store i32 [[TWICE]], ptr [[OUTPUT5]], align 4
; OPT-NEXT: [[CALL:%.*]] = call i32 @escape(ptr [[INPUT1_GEN]])
; OPT-NEXT: ret void
+;
%val = load i32, ptr %input
%twice = add i32 %val, %val
store i32 %twice, ptr %output
@@ -383,6 +380,7 @@ define ptx_kernel i32 @grid_const_partial_escapemem(ptr byval(%struct.s) %input,
; OPT-NEXT: [[ADD:%.*]] = add i32 [[VAL1]], [[VAL2]]
; OPT-NEXT: [[CALL2:%.*]] = call i32 @escape(ptr [[PTR1]])
; OPT-NEXT: ret i32 [[ADD]]
+;
%ptr1 = getelementptr inbounds %struct.s, ptr %input, i32 0, i32 0
%val1 = load i32, ptr %ptr1
%ptr2 = getelementptr inbounds %struct.s, ptr %input, i32 0, i32 1
@@ -435,6 +433,7 @@ define ptx_kernel void @grid_const_phi(ptr byval(%struct.s) align 4 %input1, ptr
; OPT-NEXT: [[VALLOADED:%.*]] = load i32, ptr [[PTRNEW]], align 4
; OPT-NEXT: store i32 [[VALLOADED]], ptr [[INOUT2]], align 4
; OPT-NEXT: ret void
+;
%val = load i32, ptr %inout
%less = icmp slt i32 %val, 0
@@ -500,6 +499,7 @@ define ptx_kernel void @grid_const_phi_ngc(ptr byval(%struct.s) align 4 %input1,
; OPT-NEXT: [[VALLOADED:%.*]] = load i32, ptr [[PTRNEW]], align 4
; OPT-NEXT: store i32 [[VALLOADED]], ptr [[INOUT2]], align 4
; OPT-NEXT: ret void
+;
%val = load i32, ptr %inout
%less = icmp slt i32 %val, 0
br i1 %less, label %first, label %second
@@ -553,6 +553,7 @@ define ptx_kernel void @grid_const_select(ptr byval(i32) align 4 %input1, ptr by
; OPT-NEXT: [[VALLOADED:%.*]] = load i32, ptr [[PTRNEW]], align 4
; OPT-NEXT: store i32 [[VALLOADED]], ptr [[INOUT2]], align 4
; OPT-NEXT: ret void
+;
%val = load i32, ptr %inout
%less = icmp slt i32 %val, 0
%ptrnew = select i1 %less, ptr %input1, ptr %input2
@@ -584,6 +585,7 @@ define ptx_kernel i32 @grid_const_ptrtoint(ptr byval(i32) %input) {
; OPT-NEXT: [[PTRVAL:%.*]] = ptrtoint ptr [[INPUT1]] to i32
; OPT-NEXT: [[KEEPALIVE:%.*]] = add i32 [[INPUT3]], [[PTRVAL]]
; OPT-NEXT: ret i32 [[KEEPALIVE]]
+;
%val = load i32, ptr %input
%ptrval = ptrtoint ptr %input to i32
%keepalive = add i32 %val, %ptrval
diff --git a/llvm/test/CodeGen/NVPTX/lower-args.ll b/llvm/test/CodeGen/NVPTX/lower-args.ll
index 23cf1a85789e4..2f8875257e67f 100644
--- a/llvm/test/CodeGen/NVPTX/lower-args.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args.ll
@@ -15,9 +15,9 @@ target triple = "nvptx64-nvidia-cuda"
; COMMON-LABEL: load_alignment
define void @load_alignment(ptr nocapture readonly byval(%class.outer) align 8 %arg) {
entry:
-; IR: call void @llvm.memcpy.p0.p101.i64(ptr align 8
-; PTX: ld.param.u64
-; PTX-NOT: ld.param.u8
+; IR: addrspacecast ptr %arg to ptr addrspace(5)
+; PTX: ld.local.u64
+; PTX-NOT: ld.local.u8
%arg.idx.val = load ptr, ptr %arg, align 8
%arg.idx1 = getelementptr %class.outer, ptr %arg, i64 0, i32 0, i32 1
%arg.idx1.val = load ptr, ptr %arg.idx1, align 8
@@ -37,28 +37,21 @@ entry:
; COMMON-LABEL: load_padding
define void @load_padding(ptr nocapture readonly byval(%class.padded) %arg) {
; PTX: {
-; PTX-NEXT: .local .align 8 .b8 __local_depot1[8];
-; PTX-NEXT: .reg .b64 %SP;
-; PTX-NEXT: .reg .b64 %SPL;
-; PTX-NEXT: .reg .b64 %rd<6>;
+; PTX-NEXT: .reg .b64 %rd<5>;
; PTX-EMPTY:
; PTX-NEXT: // %bb.0:
-; PTX-NEXT: mov.u64 %SPL, __local_depot1;
-; PTX-NEXT: cvta.local.u64 %SP, %SPL;
-; PTX-NEXT: add.u64 %rd1, %SP, 0;
-; PTX-NEXT: add.u64 %rd2, %SPL, 0;
-; PTX-NEXT: ld.param.u64 %rd3, [load_padding_param_0];
-; PTX-NEXT: st.local.u64 [%rd2], %rd3;
+; PTX-NEXT: mov.b64 %rd1, load_padding_param_0;
+; PTX-NEXT: cvta.local.u64 %rd2, %rd1;
; PTX-NEXT: { // callseq 1, 0
; PTX-NEXT: .param .b64 param0;
-; PTX-NEXT: st.param.b64 [param0], %rd1;
+; PTX-NEXT: st.param.b64 [param0], %rd2;
; PTX-NEXT: .param .b64 retval0;
; PTX-NEXT: call.uni (retval0),
; PTX-NEXT: escape,
; PTX-NEXT: (
; PTX-NEXT: param0
; PTX-NEXT: );
-; PTX-NEXT: ld.param.b64 %rd4, [retval0];
+; PTX-NEXT: ld.param.b64 %rd3, [retval0];
; PTX-NEXT: } // callseq 1
; PTX-NEXT: ret;
%tmp = call ptr @escape(ptr nonnull align 16 %arg)
diff --git a/llvm/test/CodeGen/NVPTX/variadics-backend.ll b/llvm/test/CodeGen/NVPTX/variadics-backend.ll
index 377528b94f505..eaf0ce58750b4 100644
--- a/llvm/test/CodeGen/NVPTX/variadics-backend.ll
+++ b/llvm/test/CodeGen/NVPTX/variadics-backend.ll
@@ -338,18 +338,18 @@ define dso_local i32 @variadics4(ptr noundef byval(%struct.S2) align 8 %first, .
; CHECK-PTX-LABEL: variadics4(
; CHECK-PTX: {
; CHECK-PTX-NEXT: .reg .b32 %r<2>;
-; CHECK-PTX-NEXT: .reg .b64 %rd<9>;
+; CHECK-PTX-NEXT: .reg .b64 %rd<12>;
; CHECK-PTX-EMPTY:
; CHECK-PTX-NEXT: // %bb.0: // %entry
-; CHECK-PTX-NEXT: ld.param.u64 %rd1, [variadics4_param_1];
-; CHECK-PTX-NEXT: add.s64 %rd2, %rd1, 7;
-; CHECK-PTX-NEXT: and.b64 %rd3, %rd2, -8;
-; CHECK-PTX-NEXT: ld.u64 %rd4, [%rd3];
-; CHECK-PTX-NEXT: ld.param.u64 %rd5, [variadics4_param_0];
-; CHECK-PTX-NEXT: ld.param.u64 %rd6, [variadics4_param_0+8];
-; CHECK-PTX-NEXT: add.s64 %rd7, %rd5, %rd6;
-; CHECK-PTX-NEXT: add.s64 %rd8, %rd7, %rd4;
-; CHECK-PTX-NEXT: cvt.u32.u64 %r1, %rd8;
+; CHECK-PTX-NEXT: ld.param.u64 %rd4, [variadics4_param_1];
+; CHECK-PTX-NEXT: add.s64 %rd5, %rd4, 7;
+; CHECK-PTX-NEXT: and.b64 %rd6, %rd5, -8;
+; CHECK-PTX-NEXT: ld.u64 %rd7, [%rd6];
+; CHECK-PTX-NEXT: ld.param.u64 %rd8, [variadics4_param_0];
+; CHECK-PTX-NEXT: ld.param.u64 %rd9, [variadics4_param_0+8];
+; CHECK-PTX-NEXT: add.s64 %rd10, %rd8, %rd9;
+; CHECK-PTX-NEXT: add.s64 %rd11, %rd10, %rd7;
+; CHECK-PTX-NEXT: cvt.u32.u64 %r1, %rd11;
; CHECK-PTX-NEXT: st.param.b32 [func_retval0], %r1;
; CHECK-PTX-NEXT: ret;
entry:
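The `lower-args.ll` changes above show the IR-level effect of the new conservative lowering: instead of emitting an explicit `memcpy` from the param space, NVPTXLowerArgs now inserts an `addrspacecast` of the byval pointer into the `.local` address space (addrspace 5), and loads become `ld.local`. A minimal sketch of what the lowered IR looks like (illustrative only; the type and function names here are hypothetical, not from the patch):

```llvm
%struct.pair = type { i32, i32 }

; Device function taking a byval aggregate parameter. After the new
; NVPTXLowerArgs lowering, the parameter pointer is conservatively
; treated as living in the .local address space (addrspace(5)).
define i32 @use_byval(ptr byval(%struct.pair) align 4 %p) {
  ; Cast the incoming generic pointer into the local address space;
  ; this corresponds to the "mov param" case, where the PTX calling
  ; convention owns the byval copy.
  %p.local = addrspacecast ptr %p to ptr addrspace(5)
  ; Loads through the local pointer select to ld.local.* instructions.
  %v = load i32, ptr addrspace(5) %p.local, align 4
  ret i32 %v
}
```

If instruction selection later proves the parameter is only read through constant offsets, the new NVPTXForwardParams MachineIR pass can still switch such accesses back to the cheaper "simple" `ld.param` form.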