[llvm] [NVPTX] Improve device function byval parameter lowering (PR #129188)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 28 12:47:23 PST 2025


https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/129188

>From 5f80afe54d4d1d69ca429f25b5bc57c67e9995a3 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Fri, 28 Feb 2025 18:14:23 +0000
Subject: [PATCH 1/3] pre-commit tests -- use update_llc_test_checks.py

---
 llvm/test/CodeGen/NVPTX/lower-args.ll | 268 ++++++++++++++++++++------
 1 file changed, 204 insertions(+), 64 deletions(-)

diff --git a/llvm/test/CodeGen/NVPTX/lower-args.ll b/llvm/test/CodeGen/NVPTX/lower-args.ll
index 23cf1a85789e4..66bd5e52b5f11 100644
--- a/llvm/test/CodeGen/NVPTX/lower-args.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args.ll
@@ -1,7 +1,8 @@
-; RUN: opt < %s -S -nvptx-lower-args --mtriple nvptx64-nvidia-cuda | FileCheck %s --check-prefixes COMMON,IR,IRC
-; RUN: opt < %s -S -nvptx-lower-args --mtriple nvptx64-nvidia-nvcl | FileCheck %s --check-prefixes COMMON,IR,IRO
-; RUN: llc < %s -mcpu=sm_20 --mtriple nvptx64-nvidia-cuda | FileCheck %s --check-prefixes COMMON,PTX,PTXC
-; RUN: llc < %s -mcpu=sm_20 --mtriple nvptx64-nvidia-nvcl| FileCheck %s --check-prefixes COMMON,PTX,PTXO
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -S -nvptx-lower-args --mtriple nvptx64-nvidia-cuda | FileCheck %s --check-prefixes IR,IRC
+; RUN: opt < %s -S -nvptx-lower-args --mtriple nvptx64-nvidia-nvcl | FileCheck %s --check-prefixes IR,IRO
+; RUN: llc < %s -mcpu=sm_20 --mtriple nvptx64-nvidia-cuda | FileCheck %s --check-prefixes PTX,PTXC
+; RUN: llc < %s -mcpu=sm_20 --mtriple nvptx64-nvidia-nvcl| FileCheck %s --check-prefixes PTX,PTXO
 ; RUN: %if ptxas %{ llc < %s -mcpu=sm_20 | %ptxas-verify %}
 
 target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
@@ -12,12 +13,60 @@ target triple = "nvptx64-nvidia-cuda"
 %class.padded = type { i8, i32 }
 
 ; Check that nvptx-lower-args preserves arg alignment
-; COMMON-LABEL: load_alignment
 define void @load_alignment(ptr nocapture readonly byval(%class.outer) align 8 %arg) {
+; IR-LABEL: define void @load_alignment(
+; IR-SAME: ptr readonly byval([[CLASS_OUTER:%.*]]) align 8 captures(none) [[ARG:%.*]]) {
+; IR-NEXT:  [[ENTRY:.*:]]
+; IR-NEXT:    [[ARG1:%.*]] = alloca [[CLASS_OUTER]], align 8
+; IR-NEXT:    [[ARG2:%.*]] = addrspacecast ptr [[ARG]] to ptr addrspace(101)
+; IR-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 8 [[ARG1]], ptr addrspace(101) align 8 [[ARG2]], i64 24, i1 false)
+; IR-NEXT:    [[ARG_IDX_VAL:%.*]] = load ptr, ptr [[ARG1]], align 8
+; IR-NEXT:    [[ARG_IDX1:%.*]] = getelementptr [[CLASS_OUTER]], ptr [[ARG1]], i64 0, i32 0, i32 1
+; IR-NEXT:    [[ARG_IDX1_VAL:%.*]] = load ptr, ptr [[ARG_IDX1]], align 8
+; IR-NEXT:    [[ARG_IDX2:%.*]] = getelementptr [[CLASS_OUTER]], ptr [[ARG1]], i64 0, i32 1
+; IR-NEXT:    [[ARG_IDX2_VAL:%.*]] = load i32, ptr [[ARG_IDX2]], align 8
+; IR-NEXT:    [[ARG_IDX_VAL_VAL:%.*]] = load i32, ptr [[ARG_IDX_VAL]], align 4
+; IR-NEXT:    [[ADD_I:%.*]] = add nsw i32 [[ARG_IDX_VAL_VAL]], [[ARG_IDX2_VAL]]
+; IR-NEXT:    store i32 [[ADD_I]], ptr [[ARG_IDX1_VAL]], align 4
+; IR-NEXT:    [[TMP:%.*]] = call ptr @escape(ptr nonnull [[ARG_IDX2]])
+; IR-NEXT:    ret void
+;
+; PTX-LABEL: load_alignment(
+; PTX:       {
+; PTX-NEXT:    .local .align 8 .b8 __local_depot0[24];
+; PTX-NEXT:    .reg .b64 %SP;
+; PTX-NEXT:    .reg .b64 %SPL;
+; PTX-NEXT:    .reg .b32 %r<4>;
+; PTX-NEXT:    .reg .b64 %rd<10>;
+; PTX-EMPTY:
+; PTX-NEXT:  // %bb.0: // %entry
+; PTX-NEXT:    mov.u64 %SPL, __local_depot0;
+; PTX-NEXT:    add.u64 %rd2, %SPL, 0;
+; PTX-NEXT:    ld.param.u64 %rd3, [load_alignment_param_0+16];
+; PTX-NEXT:    st.local.u64 [%rd2+16], %rd3;
+; PTX-NEXT:    ld.param.u64 %rd4, [load_alignment_param_0+8];
+; PTX-NEXT:    st.local.u64 [%rd2+8], %rd4;
+; PTX-NEXT:    ld.param.u64 %rd5, [load_alignment_param_0];
+; PTX-NEXT:    st.local.u64 [%rd2], %rd5;
+; PTX-NEXT:    add.s64 %rd6, %rd2, 16;
+; PTX-NEXT:    cvta.local.u64 %rd7, %rd6;
+; PTX-NEXT:    cvt.u32.u64 %r1, %rd3;
+; PTX-NEXT:    ld.u32 %r2, [%rd5];
+; PTX-NEXT:    add.s32 %r3, %r2, %r1;
+; PTX-NEXT:    st.u32 [%rd4], %r3;
+; PTX-NEXT:    { // callseq 0, 0
+; PTX-NEXT:    .param .b64 param0;
+; PTX-NEXT:    st.param.b64 [param0], %rd7;
+; PTX-NEXT:    .param .b64 retval0;
+; PTX-NEXT:    call.uni (retval0),
+; PTX-NEXT:    escape,
+; PTX-NEXT:    (
+; PTX-NEXT:    param0
+; PTX-NEXT:    );
+; PTX-NEXT:    ld.param.b64 %rd8, [retval0];
+; PTX-NEXT:    } // callseq 0
+; PTX-NEXT:    ret;
 entry:
-; IR: call void @llvm.memcpy.p0.p101.i64(ptr align 8
-; PTX: ld.param.u64
-; PTX-NOT: ld.param.u8
   %arg.idx.val = load ptr, ptr %arg, align 8
   %arg.idx1 = getelementptr %class.outer, ptr %arg, i64 0, i32 0, i32 1
   %arg.idx1.val = load ptr, ptr %arg.idx1, align 8
@@ -34,8 +83,16 @@ entry:
 }
 
 ; Check that nvptx-lower-args copies padding as the struct may have been a union
-; COMMON-LABEL: load_padding
 define void @load_padding(ptr nocapture readonly byval(%class.padded) %arg) {
+; IR-LABEL: define void @load_padding(
+; IR-SAME: ptr readonly byval([[CLASS_PADDED:%.*]]) captures(none) [[ARG:%.*]]) {
+; IR-NEXT:    [[ARG1:%.*]] = alloca [[CLASS_PADDED]], align 8
+; IR-NEXT:    [[ARG2:%.*]] = addrspacecast ptr [[ARG]] to ptr addrspace(101)
+; IR-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 8 [[ARG1]], ptr addrspace(101) align 8 [[ARG2]], i64 8, i1 false)
+; IR-NEXT:    [[TMP:%.*]] = call ptr @escape(ptr nonnull align 16 [[ARG1]])
+; IR-NEXT:    ret void
+;
+; PTX-LABEL: load_padding(
 ; PTX:       {
 ; PTX-NEXT:    .local .align 8 .b8 __local_depot1[8];
 ; PTX-NEXT:    .reg .b64 %SP;
@@ -45,8 +102,8 @@ define void @load_padding(ptr nocapture readonly byval(%class.padded) %arg) {
 ; PTX-NEXT:  // %bb.0:
 ; PTX-NEXT:    mov.u64 %SPL, __local_depot1;
 ; PTX-NEXT:    cvta.local.u64 %SP, %SPL;
-; PTX-NEXT:    add.u64         %rd1, %SP, 0;
-; PTX-NEXT:    add.u64         %rd2, %SPL, 0;
+; PTX-NEXT:    add.u64 %rd1, %SP, 0;
+; PTX-NEXT:    add.u64 %rd2, %SPL, 0;
 ; PTX-NEXT:    ld.param.u64 %rd3, [load_padding_param_0];
 ; PTX-NEXT:    st.local.u64 [%rd2], %rd3;
 ; PTX-NEXT:    { // callseq 1, 0
@@ -65,55 +122,115 @@ define void @load_padding(ptr nocapture readonly byval(%class.padded) %arg) {
   ret void
 }
 
-; COMMON-LABEL: ptr_generic
-define ptx_kernel void @ptr_generic(ptr %out, ptr %in) {
-; IRC:  %in3 = addrspacecast ptr %in to ptr addrspace(1)
-; IRC:  %in4 = addrspacecast ptr addrspace(1) %in3 to ptr
-; IRC:  %out1 = addrspacecast ptr %out to ptr addrspace(1)
-; IRC:  %out2 = addrspacecast ptr addrspace(1) %out1 to ptr
-; PTXC: cvta.to.global.u64
-; PTXC: cvta.to.global.u64
-; PTXC: ld.global.u32
-; PTXC: st.global.u32
-
 ; OpenCL can't make assumptions about incoming pointer, so we should generate
 ; generic pointers load/store.
-; IRO-NOT: addrspacecast
-; PTXO-NOT: cvta.to.global
-; PTXO: ld.u32
-; PTXO: st.u32
+define ptx_kernel void @ptr_generic(ptr %out, ptr %in) {
+; IRC-LABEL: define ptx_kernel void @ptr_generic(
+; IRC-SAME: ptr [[OUT:%.*]], ptr [[IN:%.*]]) {
+; IRC-NEXT:    [[IN3:%.*]] = addrspacecast ptr [[IN]] to ptr addrspace(1)
+; IRC-NEXT:    [[IN4:%.*]] = addrspacecast ptr addrspace(1) [[IN3]] to ptr
+; IRC-NEXT:    [[OUT1:%.*]] = addrspacecast ptr [[OUT]] to ptr addrspace(1)
+; IRC-NEXT:    [[OUT2:%.*]] = addrspacecast ptr addrspace(1) [[OUT1]] to ptr
+; IRC-NEXT:    [[V:%.*]] = load i32, ptr [[IN4]], align 4
+; IRC-NEXT:    store i32 [[V]], ptr [[OUT2]], align 4
+; IRC-NEXT:    ret void
+;
+; IRO-LABEL: define ptx_kernel void @ptr_generic(
+; IRO-SAME: ptr [[OUT:%.*]], ptr [[IN:%.*]]) {
+; IRO-NEXT:    [[V:%.*]] = load i32, ptr [[IN]], align 4
+; IRO-NEXT:    store i32 [[V]], ptr [[OUT]], align 4
+; IRO-NEXT:    ret void
+;
+; PTXC-LABEL: ptr_generic(
+; PTXC:       {
+; PTXC-NEXT:    .reg .b32 %r<2>;
+; PTXC-NEXT:    .reg .b64 %rd<5>;
+; PTXC-EMPTY:
+; PTXC-NEXT:  // %bb.0:
+; PTXC-NEXT:    ld.param.u64 %rd1, [ptr_generic_param_0];
+; PTXC-NEXT:    ld.param.u64 %rd2, [ptr_generic_param_1];
+; PTXC-NEXT:    cvta.to.global.u64 %rd3, %rd2;
+; PTXC-NEXT:    cvta.to.global.u64 %rd4, %rd1;
+; PTXC-NEXT:    ld.global.u32 %r1, [%rd3];
+; PTXC-NEXT:    st.global.u32 [%rd4], %r1;
+; PTXC-NEXT:    ret;
+;
+; PTXO-LABEL: ptr_generic(
+; PTXO:       {
+; PTXO-NEXT:    .reg .b32 %r<2>;
+; PTXO-NEXT:    .reg .b64 %rd<3>;
+; PTXO-EMPTY:
+; PTXO-NEXT:  // %bb.0:
+; PTXO-NEXT:    ld.param.u64 %rd1, [ptr_generic_param_0];
+; PTXO-NEXT:    ld.param.u64 %rd2, [ptr_generic_param_1];
+; PTXO-NEXT:    ld.u32 %r1, [%rd2];
+; PTXO-NEXT:    st.u32 [%rd1], %r1;
+; PTXO-NEXT:    ret;
   %v = load i32, ptr  %in, align 4
   store i32 %v, ptr %out, align 4
   ret void
 }
 
-; COMMON-LABEL: ptr_nongeneric
 define ptx_kernel void @ptr_nongeneric(ptr addrspace(1) %out, ptr addrspace(3) %in) {
-; IR-NOT: addrspacecast
-; PTX-NOT: cvta.to.global
-; PTX:  ld.shared.u32
-; PTX   st.global.u32
+; IR-LABEL: define ptx_kernel void @ptr_nongeneric(
+; IR-SAME: ptr addrspace(1) [[OUT:%.*]], ptr addrspace(3) [[IN:%.*]]) {
+; IR-NEXT:    [[V:%.*]] = load i32, ptr addrspace(3) [[IN]], align 4
+; IR-NEXT:    store i32 [[V]], ptr addrspace(1) [[OUT]], align 4
+; IR-NEXT:    ret void
+;
+; PTX-LABEL: ptr_nongeneric(
+; PTX:       {
+; PTX-NEXT:    .reg .b32 %r<2>;
+; PTX-NEXT:    .reg .b64 %rd<3>;
+; PTX-EMPTY:
+; PTX-NEXT:  // %bb.0:
+; PTX-NEXT:    ld.param.u64 %rd1, [ptr_nongeneric_param_0];
+; PTX-NEXT:    ld.param.u64 %rd2, [ptr_nongeneric_param_1];
+; PTX-NEXT:    ld.shared.u32 %r1, [%rd2];
+; PTX-NEXT:    st.global.u32 [%rd1], %r1;
+; PTX-NEXT:    ret;
   %v = load i32, ptr addrspace(3) %in, align 4
   store i32 %v, ptr addrspace(1) %out, align 4
   ret void
 }
 
-; COMMON-LABEL: ptr_as_int
- define ptx_kernel void @ptr_as_int(i64 noundef %i, i32 noundef %v) {
-; IR:   [[P:%.*]] = inttoptr i64 %i to ptr
-; IRC:  [[P1:%.*]] = addrspacecast ptr [[P]] to ptr addrspace(1)
-; IRC:  addrspacecast ptr addrspace(1) [[P1]] to ptr
-; IRO-NOT: addrspacecast
-
-; PTXC-DAG:  ld.param.u64    [[I:%rd.*]], [ptr_as_int_param_0];
-; PTXC-DAG:  ld.param.u32    [[V:%r.*]], [ptr_as_int_param_1];
-; PTXC:      cvta.to.global.u64 %[[P:rd.*]], [[I]];
-; PTXC:      st.global.u32    [%[[P]]], [[V]];
-
-; PTXO-DAG:  ld.param.u64    %[[P:rd.*]], [ptr_as_int_param_0];
-; PTXO-DAG:  ld.param.u32    [[V:%r.*]], [ptr_as_int_param_1];
-; PTXO:      st.u32   [%[[P]]], [[V]];
-
+define ptx_kernel void @ptr_as_int(i64 noundef %i, i32 noundef %v) {
+; IRC-LABEL: define ptx_kernel void @ptr_as_int(
+; IRC-SAME: i64 noundef [[I:%.*]], i32 noundef [[V:%.*]]) {
+; IRC-NEXT:    [[P:%.*]] = inttoptr i64 [[I]] to ptr
+; IRC-NEXT:    [[P1:%.*]] = addrspacecast ptr [[P]] to ptr addrspace(1)
+; IRC-NEXT:    [[P2:%.*]] = addrspacecast ptr addrspace(1) [[P1]] to ptr
+; IRC-NEXT:    store i32 [[V]], ptr [[P2]], align 4
+; IRC-NEXT:    ret void
+;
+; IRO-LABEL: define ptx_kernel void @ptr_as_int(
+; IRO-SAME: i64 noundef [[I:%.*]], i32 noundef [[V:%.*]]) {
+; IRO-NEXT:    [[P:%.*]] = inttoptr i64 [[I]] to ptr
+; IRO-NEXT:    store i32 [[V]], ptr [[P]], align 4
+; IRO-NEXT:    ret void
+;
+; PTXC-LABEL: ptr_as_int(
+; PTXC:       {
+; PTXC-NEXT:    .reg .b32 %r<2>;
+; PTXC-NEXT:    .reg .b64 %rd<3>;
+; PTXC-EMPTY:
+; PTXC-NEXT:  // %bb.0:
+; PTXC-NEXT:    ld.param.u64 %rd1, [ptr_as_int_param_0];
+; PTXC-NEXT:    ld.param.u32 %r1, [ptr_as_int_param_1];
+; PTXC-NEXT:    cvta.to.global.u64 %rd2, %rd1;
+; PTXC-NEXT:    st.global.u32 [%rd2], %r1;
+; PTXC-NEXT:    ret;
+;
+; PTXO-LABEL: ptr_as_int(
+; PTXO:       {
+; PTXO-NEXT:    .reg .b32 %r<2>;
+; PTXO-NEXT:    .reg .b64 %rd<2>;
+; PTXO-EMPTY:
+; PTXO-NEXT:  // %bb.0:
+; PTXO-NEXT:    ld.param.u64 %rd1, [ptr_as_int_param_0];
+; PTXO-NEXT:    ld.param.u32 %r1, [ptr_as_int_param_1];
+; PTXO-NEXT:    st.u32 [%rd1], %r1;
+; PTXO-NEXT:    ret;
   %p = inttoptr i64 %i to ptr
   store i32 %v, ptr %p, align 4
   ret void
@@ -121,29 +238,52 @@ define ptx_kernel void @ptr_nongeneric(ptr addrspace(1) %out, ptr addrspace(3) %
 
 %struct.S = type { i64 }
 
-; COMMON-LABEL: ptr_as_int_aggr
 define ptx_kernel void @ptr_as_int_aggr(ptr nocapture noundef readonly byval(%struct.S) align 8 %s, i32 noundef %v) {
-; IR:   [[S:%.*]] = addrspacecast ptr %s to ptr addrspace(101)
-; IR:   [[I:%.*]] = load i64, ptr addrspace(101) [[S]], align 8
-; IR:   [[P0:%.*]] = inttoptr i64 [[I]] to ptr
-; IRC:  [[P1:%.*]] = addrspacecast ptr [[P]] to ptr addrspace(1)
-; IRC:  [[P:%.*]] = addrspacecast ptr addrspace(1) [[P1]] to ptr
-; IRO-NOT: addrspacecast
-
-; PTXC-DAG:  ld.param.u64    [[I:%rd.*]], [ptr_as_int_aggr_param_0];
-; PTXC-DAG:  ld.param.u32    [[V:%r.*]], [ptr_as_int_aggr_param_1];
-; PTXC:      cvta.to.global.u64 %[[P:rd.*]], [[I]];
-; PTXC:      st.global.u32    [%[[P]]], [[V]];
-
-; PTXO-DAG:  ld.param.u64    %[[P:rd.*]], [ptr_as_int_aggr_param_0];
-; PTXO-DAG:  ld.param.u32    [[V:%r.*]], [ptr_as_int_aggr_param_1];
-; PTXO:      st.u32   [%[[P]]], [[V]];
+; IRC-LABEL: define ptx_kernel void @ptr_as_int_aggr(
+; IRC-SAME: ptr noundef readonly byval([[STRUCT_S:%.*]]) align 8 captures(none) [[S:%.*]], i32 noundef [[V:%.*]]) {
+; IRC-NEXT:    [[S3:%.*]] = addrspacecast ptr [[S]] to ptr addrspace(101)
+; IRC-NEXT:    [[I:%.*]] = load i64, ptr addrspace(101) [[S3]], align 8
+; IRC-NEXT:    [[P:%.*]] = inttoptr i64 [[I]] to ptr
+; IRC-NEXT:    [[P1:%.*]] = addrspacecast ptr [[P]] to ptr addrspace(1)
+; IRC-NEXT:    [[P2:%.*]] = addrspacecast ptr addrspace(1) [[P1]] to ptr
+; IRC-NEXT:    store i32 [[V]], ptr [[P2]], align 4
+; IRC-NEXT:    ret void
+;
+; IRO-LABEL: define ptx_kernel void @ptr_as_int_aggr(
+; IRO-SAME: ptr noundef readonly byval([[STRUCT_S:%.*]]) align 8 captures(none) [[S:%.*]], i32 noundef [[V:%.*]]) {
+; IRO-NEXT:    [[S1:%.*]] = addrspacecast ptr [[S]] to ptr addrspace(101)
+; IRO-NEXT:    [[I:%.*]] = load i64, ptr addrspace(101) [[S1]], align 8
+; IRO-NEXT:    [[P:%.*]] = inttoptr i64 [[I]] to ptr
+; IRO-NEXT:    store i32 [[V]], ptr [[P]], align 4
+; IRO-NEXT:    ret void
+;
+; PTXC-LABEL: ptr_as_int_aggr(
+; PTXC:       {
+; PTXC-NEXT:    .reg .b32 %r<2>;
+; PTXC-NEXT:    .reg .b64 %rd<3>;
+; PTXC-EMPTY:
+; PTXC-NEXT:  // %bb.0:
+; PTXC-NEXT:    ld.param.u32 %r1, [ptr_as_int_aggr_param_1];
+; PTXC-NEXT:    ld.param.u64 %rd1, [ptr_as_int_aggr_param_0];
+; PTXC-NEXT:    cvta.to.global.u64 %rd2, %rd1;
+; PTXC-NEXT:    st.global.u32 [%rd2], %r1;
+; PTXC-NEXT:    ret;
+;
+; PTXO-LABEL: ptr_as_int_aggr(
+; PTXO:       {
+; PTXO-NEXT:    .reg .b32 %r<2>;
+; PTXO-NEXT:    .reg .b64 %rd<2>;
+; PTXO-EMPTY:
+; PTXO-NEXT:  // %bb.0:
+; PTXO-NEXT:    ld.param.u32 %r1, [ptr_as_int_aggr_param_1];
+; PTXO-NEXT:    ld.param.u64 %rd1, [ptr_as_int_aggr_param_0];
+; PTXO-NEXT:    st.u32 [%rd1], %r1;
+; PTXO-NEXT:    ret;
   %i = load i64, ptr %s, align 8
   %p = inttoptr i64 %i to ptr
   store i32 %v, ptr %p, align 4
   ret void
 }
 
-
 ; Function Attrs: convergent nounwind
 declare dso_local ptr @escape(ptr) local_unnamed_addr

>From 1a4100e15cd9d16190deb44025c8651972bf521e Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Fri, 28 Feb 2025 05:10:39 +0000
Subject: [PATCH 2/3] [NVPTX] Improve byval device parameter lowering

---
 llvm/lib/Target/NVPTX/CMakeLists.txt          |   1 +
 llvm/lib/Target/NVPTX/NVPTX.h                 |   1 +
 llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp  | 169 ++++++++++++++++++
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp   |   4 +-
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   |  16 +-
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td       |  21 +--
 llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp      |  64 +++----
 llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp  |   3 +
 llvm/test/CodeGen/NVPTX/forward-ld-param.ll   | 142 +++++++++++++++
 llvm/test/CodeGen/NVPTX/i128-array.ll         |  15 +-
 .../CodeGen/NVPTX/lower-args-gridconstant.ll  |  50 +++---
 llvm/test/CodeGen/NVPTX/lower-args.ll         |  53 ++----
 llvm/test/CodeGen/NVPTX/variadics-backend.ll  |  20 +--
 .../Inputs/nvptx-basic.ll.expected            |  32 ++--
 14 files changed, 434 insertions(+), 157 deletions(-)
 create mode 100644 llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
 create mode 100644 llvm/test/CodeGen/NVPTX/forward-ld-param.ll

diff --git a/llvm/lib/Target/NVPTX/CMakeLists.txt b/llvm/lib/Target/NVPTX/CMakeLists.txt
index dfbda84534732..1cffde138eab7 100644
--- a/llvm/lib/Target/NVPTX/CMakeLists.txt
+++ b/llvm/lib/Target/NVPTX/CMakeLists.txt
@@ -16,6 +16,7 @@ set(NVPTXCodeGen_sources
   NVPTXAtomicLower.cpp
   NVPTXAsmPrinter.cpp
   NVPTXAssignValidGlobalNames.cpp
+  NVPTXForwardParams.cpp
   NVPTXFrameLowering.cpp
   NVPTXGenericToNVVM.cpp
   NVPTXISelDAGToDAG.cpp
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index ca915cd3f3732..62f51861ac55a 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -52,6 +52,7 @@ FunctionPass *createNVPTXLowerUnreachablePass(bool TrapUnreachable,
                                               bool NoTrapAfterNoreturn);
 MachineFunctionPass *createNVPTXPeephole();
 MachineFunctionPass *createNVPTXProxyRegErasurePass();
+MachineFunctionPass *createNVPTXForwardParamsPass();
 
 struct NVVMIntrRangePass : PassInfoMixin<NVVMIntrRangePass> {
   PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
diff --git a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
new file mode 100644
index 0000000000000..47d44b985363d
--- /dev/null
+++ b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
@@ -0,0 +1,214 @@
+//- NVPTXForwardParams.cpp - NVPTX Forward Device Params Removing Local Copy -//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// PTX supports 2 methods of accessing device function parameters:
+//
+//   - "simple" case: If a parameter is only loaded, and all loads can address
+//     the parameter via a constant offset, then the parameter may be loaded via
+//     the ".param" address space. This case is not possible if the parameter
+//     is stored to or has its address taken. This method is preferable when
+//     possible. Ex:
+//
+//            ld.param.u32    %r1, [foo_param_1];
+//            ld.param.u32    %r2, [foo_param_1+4];
+//
+//   - "move param" case: For more complex cases the address of the param may be
+//     placed in a register via a "mov" instruction. This "mov" also implicitly
+//     moves the param to the ".local" address space and allows for it to be
+//     written to. This essentially defers the responsibility of the byval copy
+//     to the PTX calling convention.
+//
+//            mov.b64         %rd1, foo_param_0;
+//            st.local.u32    [%rd1], 42;
+//            add.u64         %rd3, %rd1, %rd2;
+//            ld.local.u32    %r2, [%rd3];
+//
+// In NVPTXLowerArgs and SelectionDAG, we pessimistically assume that all
+// parameters will use the "move param" case and the local address space. This
+// pass is responsible for switching to the "simple" case when possible, as it
+// is more efficient.
+//
+// We do this by simply traversing uses of the param "mov" instructions and
+// trivially checking if they are all loads.
+//
+//===----------------------------------------------------------------------===//
+
+#include "NVPTX.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineOperand.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/TargetRegisterInfo.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace llvm;
+
+/// Checks whether \p U, a user of the register defined by a param "mov" (or of
+/// a register derived from it through cvta instructions), can be rewritten to
+/// access the parameter directly in the ".param" address space.
+///
+/// Supported users are the scalar and vector loads listed below, which are
+/// appended to \p LoadInsts, and local-address-space cvta conversions, whose
+/// own users are checked recursively. A cvta is appended to \p RemoveList only
+/// after all of its users were accepted, so users precede the instructions
+/// whose results they consume.
+///
+/// \returns false as soon as any other user is found (a store, address
+/// arithmetic, a call, ...). The lists may then hold partial results and must
+/// be discarded by the caller.
+static bool traverseMoveUse(MachineInstr &U, const MachineRegisterInfo &MRI,
+                            SmallVectorImpl<MachineInstr *> &RemoveList,
+                            SmallVectorImpl<MachineInstr *> &LoadInsts) {
+  switch (U.getOpcode()) {
+  case NVPTX::LD_f32:
+  case NVPTX::LD_f64:
+  case NVPTX::LD_i16:
+  case NVPTX::LD_i32:
+  case NVPTX::LD_i64:
+  case NVPTX::LD_i8:
+  case NVPTX::LDV_f32_v2:
+  case NVPTX::LDV_f32_v4:
+  case NVPTX::LDV_f64_v2:
+  case NVPTX::LDV_f64_v4:
+  case NVPTX::LDV_i16_v2:
+  case NVPTX::LDV_i16_v4:
+  case NVPTX::LDV_i32_v2:
+  case NVPTX::LDV_i32_v4:
+  case NVPTX::LDV_i64_v2:
+  case NVPTX::LDV_i64_v4:
+  case NVPTX::LDV_i8_v2:
+  case NVPTX::LDV_i8_v4: {
+    // A load only reads through the address, so it can be retargeted.
+    LoadInsts.push_back(&U);
+    return true;
+  }
+  case NVPTX::cvta_local:
+  case NVPTX::cvta_local_64:
+  case NVPTX::cvta_to_local:
+  case NVPTX::cvta_to_local_64: {
+    // Address-space conversions are transparent: accept them only if every
+    // user of the converted address is itself acceptable. Operand 0 is the
+    // converted (defined) register.
+    for (auto &U2 : MRI.use_instructions(U.operands_begin()->getReg()))
+      if (!traverseMoveUse(U2, MRI, RemoveList, LoadInsts))
+        return false;
+
+    RemoveList.push_back(&U);
+    return true;
+  }
+  default:
+    return false;
+  }
+}
+
+/// Tries to rewrite all uses of the param "mov" instruction \p Mov so that they
+/// read the parameter directly from the ".param" address space.
+///
+/// If every (transitive) user of the moved address is a load, each load's base
+/// address operand is replaced by the parameter symbol and its address-space
+/// operand by NVPTX::AddressSpace::Param. \p Mov and the intermediate cvta
+/// instructions are then appended to \p RemoveList for the caller to erase.
+///
+/// \returns true if the rewrite was performed; false, leaving the function
+/// unchanged, if any user is unsupported.
+static bool eliminateMove(MachineInstr &Mov, const MachineRegisterInfo &MRI,
+                          SmallVectorImpl<MachineInstr *> &RemoveList) {
+  SmallVector<MachineInstr *, 16> MaybeRemoveList;
+  SmallVector<MachineInstr *, 16> LoadInsts;
+
+  for (auto &U : MRI.use_instructions(Mov.operands_begin()->getReg()))
+    if (!traverseMoveUse(U, MRI, MaybeRemoveList, LoadInsts))
+      return false;
+
+  RemoveList.append(MaybeRemoveList);
+  RemoveList.push_back(&Mov);
+
+  const MachineOperand *ParamSymbol = Mov.uses().begin();
+  assert(ParamSymbol->isSymbol());
+
+  // Positions, within uses(), of the base-address and address-space operands
+  // of the LD/LDV instructions accepted by traverseMoveUse.
+  // NOTE(review): these hard-code the operand layout of those instruction
+  // definitions; keep them in sync with NVPTXInstrInfo.td.
+  constexpr unsigned LDInstBasePtrOpIdx = 6;
+  constexpr unsigned LDInstAddrSpaceOpIdx = 2;
+  for (auto *LI : LoadInsts) {
+    (LI->uses().begin() + LDInstBasePtrOpIdx)
+        ->ChangeToES(ParamSymbol->getSymbolName());
+    (LI->uses().begin() + LDInstAddrSpaceOpIdx)
+        ->ChangeToImmediate(NVPTX::AddressSpace::Param);
+  }
+  return true;
+}
+
+/// Rewrites every param "mov" (MOV32_PARAM / MOV64_PARAM) in the entry block
+/// of \p MF whose uses are all loads, then erases the now-dead "mov" and cvta
+/// instructions. \returns true if the function was changed.
+///
+/// Only the entry block is scanned. NOTE(review): this assumes param moves
+/// are only emitted there (they come from lowering the formal arguments).
+static bool forwardDeviceParams(MachineFunction &MF) {
+  const auto &MRI = MF.getRegInfo();
+
+  bool Changed = false;
+  SmallVector<MachineInstr *, 16> RemoveList;
+  // Erasure is deferred to the loop below, so the block is not modified while
+  // it is being iterated here.
+  for (MachineInstr &MI : *MF.begin())
+    if (MI.getOpcode() == NVPTX::MOV32_PARAM ||
+        MI.getOpcode() == NVPTX::MOV64_PARAM)
+      Changed |= eliminateMove(MI, MRI, RemoveList);
+
+  // RemoveList orders users before the instructions defining their operands,
+  // so each instruction is erased only after its users are gone.
+  for (auto *MI : RemoveList)
+    MI->eraseFromParent();
+
+  return Changed;
+}
+
+/// ----------------------------------------------------------------------------
+///                       Pass (Manager) Boilerplate
+/// ----------------------------------------------------------------------------
+
+namespace llvm {
+void initializeNVPTXForwardParamsPassPass(PassRegistry &);
+} // namespace llvm
+
+namespace {
+/// Legacy machine-function pass wrapper around forwardDeviceParams().
+struct NVPTXForwardParamsPass : public MachineFunctionPass {
+  static char ID;
+  NVPTXForwardParamsPass() : MachineFunctionPass(ID) {
+    initializeNVPTXForwardParamsPassPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnMachineFunction(MachineFunction &MF) override;
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    // Only non-terminator instructions are rewritten or erased, so the
+    // control-flow graph is unchanged.
+    AU.setPreservesCFG();
+    MachineFunctionPass::getAnalysisUsage(AU);
+  }
+};
+} // namespace
+
+char NVPTXForwardParamsPass::ID = 0;
+
+INITIALIZE_PASS(NVPTXForwardParamsPass, "nvptx-forward-params",
+                "NVPTX Forward Params", false, false)
+
+bool NVPTXForwardParamsPass::runOnMachineFunction(MachineFunction &MF) {
+  return forwardDeviceParams(MF);
+}
+
+MachineFunctionPass *llvm::createNVPTXForwardParamsPass() {
+  return new NVPTXForwardParamsPass();
+}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 8a5cdd7412bf3..0461ed4712221 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -2197,11 +2197,11 @@ static SDValue selectBaseADDR(SDValue N, SelectionDAG *DAG) {
   if (N.getOpcode() == NVPTXISD::Wrapper)
     return N.getOperand(0);
 
-  // addrspacecast(MoveParam(arg_symbol) to addrspace(PARAM)) -> arg_symbol
+  // addrspacecast(Wrapper(arg_symbol) to addrspace(PARAM)) -> arg_symbol
   if (AddrSpaceCastSDNode *CastN = dyn_cast<AddrSpaceCastSDNode>(N))
     if (CastN->getSrcAddressSpace() == ADDRESS_SPACE_GENERIC &&
         CastN->getDestAddressSpace() == ADDRESS_SPACE_PARAM &&
-        CastN->getOperand(0).getOpcode() == NVPTXISD::MoveParam)
+        CastN->getOperand(0).getOpcode() == NVPTXISD::Wrapper)
       return selectBaseADDR(CastN->getOperand(0).getOperand(0), DAG);
 
   if (auto *FIN = dyn_cast<FrameIndexSDNode>(N))
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index f5760cdb45306..3e755c25fd91a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3376,10 +3376,18 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
     assert(ObjectVT == Ins[InsIdx].VT &&
            "Ins type did not match function type");
     SDValue Arg = getParamSymbol(DAG, i, PtrVT);
-    SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
-    if (p.getNode())
-      p.getNode()->setIROrder(i + 1);
-    InVals.push_back(p);
+
+    SDValue P;
+    if (isKernelFunction(*F)) {
+      P = DAG.getNode(NVPTXISD::Wrapper, dl, ObjectVT, Arg);
+      P.getNode()->setIROrder(i + 1);
+    } else {
+      P = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
+      P.getNode()->setIROrder(i + 1);
+      P = DAG.getAddrSpaceCast(dl, ObjectVT, P, ADDRESS_SPACE_LOCAL,
+                               ADDRESS_SPACE_GENERIC);
+    }
+    InVals.push_back(P);
   }
 
   if (!OutChains.empty())
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 36a0a06bdb8aa..6edb0998760b8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -2324,7 +2324,7 @@ def SDTCallArgProfile : SDTypeProfile<0, 2, [SDTCisInt<0>]>;
 def SDTCallArgMarkProfile : SDTypeProfile<0, 0, []>;
 def SDTCallVoidProfile : SDTypeProfile<0, 1, []>;
 def SDTCallValProfile : SDTypeProfile<1, 0, []>;
-def SDTMoveParamProfile : SDTypeProfile<1, 1, []>;
+def SDTMoveParamProfile : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisInt<1>]>;
 def SDTStoreRetvalProfile : SDTypeProfile<0, 2, [SDTCisInt<0>]>;
 def SDTStoreRetvalV2Profile : SDTypeProfile<0, 3, [SDTCisInt<0>]>;
 def SDTStoreRetvalV4Profile : SDTypeProfile<0, 5, [SDTCisInt<0>]>;
@@ -2688,29 +2688,14 @@ def DeclareScalarRegInst :
             ".reg .b$size param$a;",
             [(DeclareScalarParam (i32 imm:$a), (i32 imm:$size), (i32 1))]>;
 
-class MoveParamInst<ValueType T, NVPTXRegClass regclass, string asmstr> :
-  NVPTXInst<(outs regclass:$dst), (ins regclass:$src),
-            !strconcat("mov", asmstr, " \t$dst, $src;"),
-            [(set T:$dst, (MoveParam T:$src))]>;
-
 class MoveParamSymbolInst<NVPTXRegClass regclass, Operand srcty, ValueType vt,
                           string asmstr> :
   NVPTXInst<(outs regclass:$dst), (ins srcty:$src),
             !strconcat("mov", asmstr, " \t$dst, $src;"),
             [(set vt:$dst, (MoveParam texternalsym:$src))]>;
 
-def MoveParamI64 : MoveParamInst<i64, Int64Regs, ".b64">;
-def MoveParamI32 : MoveParamInst<i32, Int32Regs, ".b32">;
-
-def MoveParamSymbolI64 : MoveParamSymbolInst<Int64Regs, i64imm, i64, ".b64">;
-def MoveParamSymbolI32 : MoveParamSymbolInst<Int32Regs, i32imm, i32, ".b32">;
-
-def MoveParamI16 :
-  NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src),
-            "cvt.u16.u32 \t$dst, $src;", // ??? Why cvt.u16.u32 ?
-            [(set i16:$dst, (MoveParam i16:$src))]>;
-def MoveParamF64 : MoveParamInst<f64, Float64Regs, ".f64">;
-def MoveParamF32 : MoveParamInst<f32, Float32Regs, ".f32">;
+def MOV64_PARAM : MoveParamSymbolInst<Int64Regs, i64imm, i64, ".b64">;
+def MOV32_PARAM : MoveParamSymbolInst<Int32Regs, i32imm, i32, ".b32">;
 
 class PseudoUseParamInst<NVPTXRegClass regclass, ValueType vt> :
   NVPTXInst<(outs), (ins regclass:$src),
diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index c763b54c8dbfe..5161a682fb01a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -153,6 +153,7 @@
 #include "llvm/Pass.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/NVPTXAddrSpace.h"
 #include <numeric>
 #include <queue>
 
@@ -373,19 +374,19 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
   Type *StructType = Arg->getParamByValType();
   const DataLayout &DL = Func->getDataLayout();
 
-  uint64_t NewArgAlign =
-      TLI->getFunctionParamOptimizedAlign(Func, StructType, DL).value();
-  uint64_t CurArgAlign =
-      Arg->getAttribute(Attribute::Alignment).getValueAsInt();
+  const Align NewArgAlign =
+      TLI->getFunctionParamOptimizedAlign(Func, StructType, DL);
+  const Align CurArgAlign = Arg->getParamAlign().valueOrOne();
 
   if (CurArgAlign >= NewArgAlign)
     return;
 
-  LLVM_DEBUG(dbgs() << "Try to use alignment " << NewArgAlign << " instead of "
-                    << CurArgAlign << " for " << *Arg << '\n');
+  LLVM_DEBUG(dbgs() << "Try to use alignment " << NewArgAlign.value()
+                    << " instead of " << CurArgAlign.value() << " for " << *Arg
+                    << '\n');
 
   auto NewAlignAttr =
-      Attribute::get(Func->getContext(), Attribute::Alignment, NewArgAlign);
+      Attribute::getWithAlignment(Func->getContext(), NewArgAlign);
   Arg->removeAttr(Attribute::Alignment);
   Arg->addAttr(NewAlignAttr);
 
@@ -402,7 +403,6 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
   SmallVector<Load> Loads;
   std::queue<LoadContext> Worklist;
   Worklist.push({ArgInParamAS, 0});
-  bool IsGridConstant = isParamGridConstant(*Arg);
 
   while (!Worklist.empty()) {
     LoadContext Ctx = Worklist.front();
@@ -411,15 +411,9 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
     for (User *CurUser : Ctx.InitialVal->users()) {
       if (auto *I = dyn_cast<LoadInst>(CurUser)) {
         Loads.push_back({I, Ctx.Offset});
-        continue;
-      }
-
-      if (auto *I = dyn_cast<BitCastInst>(CurUser)) {
-        Worklist.push({I, Ctx.Offset});
-        continue;
-      }
-
-      if (auto *I = dyn_cast<GetElementPtrInst>(CurUser)) {
+      } else if (isa<BitCastInst>(CurUser) || isa<AddrSpaceCastInst>(CurUser)) {
+        Worklist.push({cast<Instruction>(CurUser), Ctx.Offset});
+      } else if (auto *I = dyn_cast<GetElementPtrInst>(CurUser)) {
         APInt OffsetAccumulated =
             APInt::getZero(DL.getIndexSizeInBits(ADDRESS_SPACE_PARAM));
 
@@ -431,26 +425,13 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
         assert(Offset != OffsetLimit && "Expect Offset less than UINT64_MAX");
 
         Worklist.push({I, Ctx.Offset + Offset});
-        continue;
       }
-
-      if (isa<MemTransferInst>(CurUser))
-        continue;
-
-      // supported for grid_constant
-      if (IsGridConstant &&
-          (isa<CallInst>(CurUser) || isa<StoreInst>(CurUser) ||
-           isa<PtrToIntInst>(CurUser)))
-        continue;
-
-      llvm_unreachable("All users must be one of: load, "
-                       "bitcast, getelementptr, call, store, ptrtoint");
     }
   }
 
   for (Load &CurLoad : Loads) {
-    Align NewLoadAlign(std::gcd(NewArgAlign, CurLoad.Offset));
-    Align CurLoadAlign(CurLoad.Inst->getAlign());
+    Align NewLoadAlign(std::gcd(NewArgAlign.value(), CurLoad.Offset));
+    Align CurLoadAlign = CurLoad.Inst->getAlign();
     CurLoad.Inst->setAlignment(std::max(NewLoadAlign, CurLoadAlign));
   }
 }
@@ -641,7 +622,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
     copyByValParam(*Func, *Arg);
 }
 
-void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
+static void markPointerAsAS(Value *Ptr, const unsigned AS) {
   if (Ptr->getType()->getPointerAddressSpace() != ADDRESS_SPACE_GENERIC)
     return;
 
@@ -658,8 +639,7 @@ void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
   }
 
   Instruction *PtrInGlobal = new AddrSpaceCastInst(
-      Ptr, PointerType::get(Ptr->getContext(), ADDRESS_SPACE_GLOBAL),
-      Ptr->getName(), InsertPt);
+      Ptr, PointerType::get(Ptr->getContext(), AS), Ptr->getName(), InsertPt);
   Value *PtrInGeneric = new AddrSpaceCastInst(PtrInGlobal, Ptr->getType(),
                                               Ptr->getName(), InsertPt);
   // Replace with PtrInGeneric all uses of Ptr except PtrInGlobal.
@@ -667,6 +647,10 @@ void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
   PtrInGlobal->setOperand(0, Ptr);
 }
 
+void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
+  markPointerAsAS(Ptr, ADDRESS_SPACE_GLOBAL);
+}
+
 // =============================================================================
 // Main function for this pass.
 // =============================================================================
@@ -724,9 +708,15 @@ bool NVPTXLowerArgs::runOnKernelFunction(const NVPTXTargetMachine &TM,
 bool NVPTXLowerArgs::runOnDeviceFunction(const NVPTXTargetMachine &TM,
                                          Function &F) {
   LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n");
+
+  const auto *TLI =
+      cast<NVPTXTargetLowering>(TM.getSubtargetImpl()->getTargetLowering());
+
   for (Argument &Arg : F.args())
-    if (Arg.getType()->isPointerTy() && Arg.hasByValAttr())
-      handleByValParam(TM, &Arg);
+    if (Arg.getType()->isPointerTy() && Arg.hasByValAttr()) {
+      markPointerAsAS(&Arg, ADDRESS_SPACE_LOCAL);
+      adjustByValArgAlignment(&Arg, &Arg, TLI);
+    }
   return true;
 }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
index f2afa6fc20bfa..229fecf2d3b10 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
@@ -100,6 +100,7 @@ void initializeNVPTXLowerUnreachablePass(PassRegistry &);
 void initializeNVPTXCtorDtorLoweringLegacyPass(PassRegistry &);
 void initializeNVPTXLowerArgsPass(PassRegistry &);
 void initializeNVPTXProxyRegErasurePass(PassRegistry &);
+void initializeNVPTXForwardParamsPassPass(PassRegistry &);
 void initializeNVVMIntrRangePass(PassRegistry &);
 void initializeNVVMReflectPass(PassRegistry &);
 void initializeNVPTXAAWrapperPassPass(PassRegistry &);
@@ -127,6 +128,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXTarget() {
   initializeNVPTXCtorDtorLoweringLegacyPass(PR);
   initializeNVPTXLowerAggrCopiesPass(PR);
   initializeNVPTXProxyRegErasurePass(PR);
+  initializeNVPTXForwardParamsPassPass(PR);
   initializeNVPTXDAGToDAGISelLegacyPass(PR);
   initializeNVPTXAAWrapperPassPass(PR);
   initializeNVPTXExternalAAWrapperPass(PR);
@@ -429,6 +431,7 @@ bool NVPTXPassConfig::addInstSelector() {
 }
 
 void NVPTXPassConfig::addPreRegAlloc() {
+  addPass(createNVPTXForwardParamsPass());
   // Remove Proxy Register pseudo instructions used to keep `callseq_end` alive.
   addPass(createNVPTXProxyRegErasurePass());
 }
diff --git a/llvm/test/CodeGen/NVPTX/forward-ld-param.ll b/llvm/test/CodeGen/NVPTX/forward-ld-param.ll
new file mode 100644
index 0000000000000..c4e56d197edc0
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/forward-ld-param.ll
@@ -0,0 +1,142 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s | FileCheck %s
+
+target triple = "nvptx64-nvidia-cuda"
+
+define i32 @test_ld_param_const(ptr byval(i32) %a) {
+; CHECK-LABEL: test_ld_param_const(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_ld_param_const_param_0+4];
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %p2 = getelementptr i32, ptr %a, i32 1
+  %ld = load i32, ptr %p2
+  ret i32 %ld
+}
+
+define i32 @test_ld_param_non_const(ptr byval([10 x i32]) %a, i32 %b) {
+; CHECK-LABEL: test_ld_param_non_const(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<6>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    mov.b64 %rd1, test_ld_param_non_const_param_0;
+; CHECK-NEXT:    cvta.local.u64 %rd2, %rd1;
+; CHECK-NEXT:    cvta.to.local.u64 %rd3, %rd2;
+; CHECK-NEXT:    ld.param.s32 %rd4, [test_ld_param_non_const_param_1];
+; CHECK-NEXT:    add.s64 %rd5, %rd3, %rd4;
+; CHECK-NEXT:    ld.local.u32 %r1, [%rd5];
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %p2 = getelementptr i8, ptr %a, i32 %b
+  %ld = load i32, ptr %p2
+  ret i32 %ld
+}
+
+declare void @escape(ptr)
+declare void @byval_user(ptr byval(i32))
+
+define void @test_ld_param_escaping(ptr byval(i32) %a) {
+; CHECK-LABEL: test_ld_param_escaping(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    mov.b64 %rd1, test_ld_param_escaping_param_0;
+; CHECK-NEXT:    cvta.local.u64 %rd2, %rd1;
+; CHECK-NEXT:    { // callseq 0, 0
+; CHECK-NEXT:    .param .b64 param0;
+; CHECK-NEXT:    st.param.b64 [param0], %rd2;
+; CHECK-NEXT:    call.uni
+; CHECK-NEXT:    escape,
+; CHECK-NEXT:    (
+; CHECK-NEXT:    param0
+; CHECK-NEXT:    );
+; CHECK-NEXT:    } // callseq 0
+; CHECK-NEXT:    ret;
+  call void @escape(ptr %a)
+  ret void
+}
+
+define void @test_ld_param_byval(ptr byval(i32) %a) {
+; CHECK-LABEL: test_ld_param_byval(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_ld_param_byval_param_0];
+; CHECK-NEXT:    { // callseq 1, 0
+; CHECK-NEXT:    .param .align 4 .b8 param0[4];
+; CHECK-NEXT:    st.param.b32 [param0], %r1;
+; CHECK-NEXT:    call.uni
+; CHECK-NEXT:    byval_user,
+; CHECK-NEXT:    (
+; CHECK-NEXT:    param0
+; CHECK-NEXT:    );
+; CHECK-NEXT:    } // callseq 1
+; CHECK-NEXT:    ret;
+  call void @byval_user(ptr %a)
+  ret void
+}
+
+define i32 @test_modify_param(ptr byval([10 x i32]) %a, i32 %b, i32 %c ) {
+; CHECK-LABEL: test_modify_param(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-NEXT:    .reg .b64 %rd<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    mov.b64 %rd1, test_modify_param_param_0;
+; CHECK-NEXT:    cvta.local.u64 %rd2, %rd1;
+; CHECK-NEXT:    cvta.to.local.u64 %rd3, %rd2;
+; CHECK-NEXT:    ld.param.u32 %r1, [test_modify_param_param_1];
+; CHECK-NEXT:    ld.param.u32 %r2, [test_modify_param_param_2];
+; CHECK-NEXT:    st.local.u32 [%rd3+2], %r1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT:    ret;
+  %p2 = getelementptr i8, ptr %a, i32 2
+  store volatile i32 %b, ptr %p2
+  ret i32 %c
+}
+
+define i32 @test_multi_block(ptr byval([10 x i32]) %a, i1 %p) {
+; CHECK-LABEL: test_multi_block(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<3>;
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u8 %rs1, [test_multi_block_param_1];
+; CHECK-NEXT:    and.b16 %rs2, %rs1, 1;
+; CHECK-NEXT:    setp.eq.b16 %p1, %rs2, 1;
+; CHECK-NEXT:    not.pred %p2, %p1;
+; CHECK-NEXT:    @%p2 bra $L__BB5_2;
+; CHECK-NEXT:  // %bb.1: // %if
+; CHECK-NEXT:    ld.param.u32 %r4, [test_multi_block_param_0+4];
+; CHECK-NEXT:    bra.uni $L__BB5_3;
+; CHECK-NEXT:  $L__BB5_2: // %else
+; CHECK-NEXT:    ld.param.u32 %r4, [test_multi_block_param_0+8];
+; CHECK-NEXT:  $L__BB5_3: // %end
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r4;
+; CHECK-NEXT:    ret;
+  br i1 %p, label %if, label %else
+if:
+  %p2 = getelementptr i8, ptr %a, i32 4
+  %v2  = load i32, ptr %p2
+  br label %end
+else:
+  %p3 = getelementptr i8, ptr %a, i32 8
+  %v3 = load i32, ptr %p3
+  br label %end
+end:
+  %v = phi i32 [ %v2, %if ], [ %v3, %else ]
+  ret i32 %v
+}
diff --git a/llvm/test/CodeGen/NVPTX/i128-array.ll b/llvm/test/CodeGen/NVPTX/i128-array.ll
index 348df8dcc7373..baa18880de840 100644
--- a/llvm/test/CodeGen/NVPTX/i128-array.ll
+++ b/llvm/test/CodeGen/NVPTX/i128-array.ll
@@ -27,16 +27,15 @@ define [2 x i128] @foo(i64 %a, i32 %b) {
 define [2 x i128] @foo2(ptr byval([2 x i128]) %a) {
 ; CHECK-LABEL: foo2(
 ; CHECK:       {
-; CHECK-NEXT:    .reg .b64 %rd<6>;
+; CHECK-NEXT:    .reg .b64 %rd<9>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    mov.b64 %rd1, foo2_param_0;
-; CHECK-NEXT:    ld.param.u64 %rd2, [foo2_param_0+8];
-; CHECK-NEXT:    ld.param.u64 %rd3, [foo2_param_0];
-; CHECK-NEXT:    ld.param.u64 %rd4, [foo2_param_0+24];
-; CHECK-NEXT:    ld.param.u64 %rd5, [foo2_param_0+16];
-; CHECK-NEXT:    st.param.v2.b64 [func_retval0], {%rd3, %rd2};
-; CHECK-NEXT:    st.param.v2.b64 [func_retval0+16], {%rd5, %rd4};
+; CHECK-NEXT:    ld.param.u64 %rd5, [foo2_param_0+8];
+; CHECK-NEXT:    ld.param.u64 %rd6, [foo2_param_0];
+; CHECK-NEXT:    ld.param.u64 %rd7, [foo2_param_0+24];
+; CHECK-NEXT:    ld.param.u64 %rd8, [foo2_param_0+16];
+; CHECK-NEXT:    st.param.v2.b64 [func_retval0], {%rd6, %rd5};
+; CHECK-NEXT:    st.param.v2.b64 [func_retval0+16], {%rd8, %rd7};
 ; CHECK-NEXT:    ret;
   %ptr0 = getelementptr [2 x i128], ptr %a, i64 0, i32 0
   %1 = load i128, i128* %ptr0
diff --git a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
index fe15be5663be1..90f9306d036cd 100644
--- a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
@@ -12,9 +12,8 @@ define dso_local noundef i32 @non_kernel_function(ptr nocapture noundef readonly
 ; OPT-LABEL: define dso_local noundef i32 @non_kernel_function(
 ; OPT-SAME: ptr noundef readonly byval([[STRUCT_UINT4:%.*]]) align 16 captures(none) [[A:%.*]], i1 noundef zeroext [[B:%.*]], i32 noundef [[C:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 ; OPT-NEXT:  [[ENTRY:.*:]]
-; OPT-NEXT:    [[A1:%.*]] = alloca [[STRUCT_UINT4]], align 16
-; OPT-NEXT:    [[A2:%.*]] = addrspacecast ptr [[A]] to ptr addrspace(101)
-; OPT-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 16 [[A1]], ptr addrspace(101) align 16 [[A2]], i64 16, i1 false)
+; OPT-NEXT:    [[A2:%.*]] = addrspacecast ptr [[A]] to ptr addrspace(5)
+; OPT-NEXT:    [[A1:%.*]] = addrspacecast ptr addrspace(5) [[A2]] to ptr
 ; OPT-NEXT:    [[A_:%.*]] = select i1 [[B]], ptr [[A1]], ptr addrspacecast (ptr addrspace(1) @gi to ptr)
 ; OPT-NEXT:    [[IDX_EXT:%.*]] = sext i32 [[C]] to i64
 ; OPT-NEXT:    [[ADD_PTR:%.*]] = getelementptr inbounds i8, ptr [[A_]], i64 [[IDX_EXT]]
@@ -23,38 +22,29 @@ define dso_local noundef i32 @non_kernel_function(ptr nocapture noundef readonly
 ;
 ; PTX-LABEL: non_kernel_function(
 ; PTX:       {
-; PTX-NEXT:    .local .align 16 .b8 __local_depot0[16];
-; PTX-NEXT:    .reg .b64 %SP;
-; PTX-NEXT:    .reg .b64 %SPL;
 ; PTX-NEXT:    .reg .pred %p<2>;
 ; PTX-NEXT:    .reg .b16 %rs<3>;
 ; PTX-NEXT:    .reg .b32 %r<11>;
-; PTX-NEXT:    .reg .b64 %rd<10>;
+; PTX-NEXT:    .reg .b64 %rd<8>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0: // %entry
-; PTX-NEXT:    mov.u64 %SPL, __local_depot0;
-; PTX-NEXT:    cvta.local.u64 %SP, %SPL;
+; PTX-NEXT:    mov.b64 %rd1, non_kernel_function_param_0;
+; PTX-NEXT:    cvta.local.u64 %rd2, %rd1;
 ; PTX-NEXT:    ld.param.u8 %rs1, [non_kernel_function_param_1];
 ; PTX-NEXT:    and.b16 %rs2, %rs1, 1;
 ; PTX-NEXT:    setp.eq.b16 %p1, %rs2, 1;
-; PTX-NEXT:    add.u64 %rd1, %SP, 0;
-; PTX-NEXT:    add.u64 %rd2, %SPL, 0;
-; PTX-NEXT:    ld.param.s32 %rd3, [non_kernel_function_param_2];
-; PTX-NEXT:    ld.param.u64 %rd4, [non_kernel_function_param_0+8];
-; PTX-NEXT:    st.local.u64 [%rd2+8], %rd4;
-; PTX-NEXT:    ld.param.u64 %rd5, [non_kernel_function_param_0];
-; PTX-NEXT:    st.local.u64 [%rd2], %rd5;
-; PTX-NEXT:    mov.u64 %rd6, gi;
-; PTX-NEXT:    cvta.global.u64 %rd7, %rd6;
-; PTX-NEXT:    selp.b64 %rd8, %rd1, %rd7, %p1;
-; PTX-NEXT:    add.s64 %rd9, %rd8, %rd3;
-; PTX-NEXT:    ld.u8 %r1, [%rd9];
-; PTX-NEXT:    ld.u8 %r2, [%rd9+1];
+; PTX-NEXT:    mov.u64 %rd3, gi;
+; PTX-NEXT:    cvta.global.u64 %rd4, %rd3;
+; PTX-NEXT:    selp.b64 %rd5, %rd2, %rd4, %p1;
+; PTX-NEXT:    ld.param.s32 %rd6, [non_kernel_function_param_2];
+; PTX-NEXT:    add.s64 %rd7, %rd5, %rd6;
+; PTX-NEXT:    ld.u8 %r1, [%rd7];
+; PTX-NEXT:    ld.u8 %r2, [%rd7+1];
 ; PTX-NEXT:    shl.b32 %r3, %r2, 8;
 ; PTX-NEXT:    or.b32 %r4, %r3, %r1;
-; PTX-NEXT:    ld.u8 %r5, [%rd9+2];
+; PTX-NEXT:    ld.u8 %r5, [%rd7+2];
 ; PTX-NEXT:    shl.b32 %r6, %r5, 16;
-; PTX-NEXT:    ld.u8 %r7, [%rd9+3];
+; PTX-NEXT:    ld.u8 %r7, [%rd7+3];
 ; PTX-NEXT:    shl.b32 %r8, %r7, 24;
 ; PTX-NEXT:    or.b32 %r9, %r8, %r6;
 ; PTX-NEXT:    or.b32 %r10, %r9, %r4;
@@ -91,6 +81,7 @@ define ptx_kernel void @grid_const_int(ptr byval(i32) align 4 %input1, i32 %inpu
 ; OPT-NEXT:    [[ADD:%.*]] = add i32 [[TMP]], [[INPUT2]]
 ; OPT-NEXT:    store i32 [[ADD]], ptr [[OUT3]], align 4
 ; OPT-NEXT:    ret void
+;
   %tmp = load i32, ptr %input1, align 4
   %add = add i32 %tmp, %input2
   store i32 %add, ptr %out
@@ -125,6 +116,7 @@ define ptx_kernel void @grid_const_struct(ptr byval(%struct.s) align 4 %input, p
 ; OPT-NEXT:    [[ADD:%.*]] = add i32 [[TMP1]], [[TMP2]]
 ; OPT-NEXT:    store i32 [[ADD]], ptr [[OUT5]], align 4
 ; OPT-NEXT:    ret void
+;
   %gep1 = getelementptr inbounds %struct.s, ptr %input, i32 0, i32 0
   %gep2 = getelementptr inbounds %struct.s, ptr %input, i32 0, i32 1
   %int1 = load i32, ptr %gep1
@@ -165,6 +157,7 @@ define ptx_kernel void @grid_const_escape(ptr byval(%struct.s) align 4 %input) {
 ; OPT-NEXT:    [[INPUT_PARAM_GEN:%.*]] = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) [[INPUT_PARAM]])
 ; OPT-NEXT:    [[CALL:%.*]] = call i32 @escape(ptr [[INPUT_PARAM_GEN]])
 ; OPT-NEXT:    ret void
+;
   %call = call i32 @escape(ptr %input)
   ret void
 }
@@ -222,6 +215,7 @@ define ptx_kernel void @multiple_grid_const_escape(ptr byval(%struct.s) align 4
 ; OPT-NEXT:    store i32 [[A]], ptr [[A_ADDR]], align 4
 ; OPT-NEXT:    [[CALL:%.*]] = call i32 @escape3(ptr [[INPUT_PARAM_GEN]], ptr [[A_ADDR]], ptr [[B_PARAM_GEN]])
 ; OPT-NEXT:    ret void
+;
   %a.addr = alloca i32, align 4
   store i32 %a, ptr %a.addr, align 4
   %call = call i32 @escape3(ptr %input, ptr %a.addr, ptr %b)
@@ -249,6 +243,7 @@ define ptx_kernel void @grid_const_memory_escape(ptr byval(%struct.s) align 4 %i
 ; OPT-NEXT:    [[INPUT1:%.*]] = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) [[INPUT_PARAM]])
 ; OPT-NEXT:    store ptr [[INPUT1]], ptr [[ADDR5]], align 8
 ; OPT-NEXT:    ret void
+;
   store ptr %input, ptr %addr, align 8
   ret void
 }
@@ -282,6 +277,7 @@ define ptx_kernel void @grid_const_inlineasm_escape(ptr byval(%struct.s) align 4
 ; OPT-NEXT:    [[TMP2:%.*]] = call i64 asm "add.s64 $0, $1, $2
 ; OPT-NEXT:    store i64 [[TMP2]], ptr [[RESULT5]], align 8
 ; OPT-NEXT:    ret void
+;
   %tmpptr1 = getelementptr inbounds %struct.s, ptr %input, i32 0, i32 0
   %tmpptr2 = getelementptr inbounds %struct.s, ptr %input, i32 0, i32 1
   %1 = call i64 asm "add.s64 $0, $1, $2;", "=l,l,l"(ptr %tmpptr1, ptr %tmpptr2) #1
@@ -330,6 +326,7 @@ define ptx_kernel void @grid_const_partial_escape(ptr byval(i32) %input, ptr %ou
 ; OPT-NEXT:    store i32 [[TWICE]], ptr [[OUTPUT5]], align 4
 ; OPT-NEXT:    [[CALL:%.*]] = call i32 @escape(ptr [[INPUT1_GEN]])
 ; OPT-NEXT:    ret void
+;
   %val = load i32, ptr %input
   %twice = add i32 %val, %val
   store i32 %twice, ptr %output
@@ -383,6 +380,7 @@ define ptx_kernel i32 @grid_const_partial_escapemem(ptr byval(%struct.s) %input,
 ; OPT-NEXT:    [[ADD:%.*]] = add i32 [[VAL1]], [[VAL2]]
 ; OPT-NEXT:    [[CALL2:%.*]] = call i32 @escape(ptr [[PTR1]])
 ; OPT-NEXT:    ret i32 [[ADD]]
+;
   %ptr1 = getelementptr inbounds %struct.s, ptr %input, i32 0, i32 0
   %val1 = load i32, ptr %ptr1
   %ptr2 = getelementptr inbounds %struct.s, ptr %input, i32 0, i32 1
@@ -435,6 +433,7 @@ define ptx_kernel void @grid_const_phi(ptr byval(%struct.s) align 4 %input1, ptr
 ; OPT-NEXT:    [[VALLOADED:%.*]] = load i32, ptr [[PTRNEW]], align 4
 ; OPT-NEXT:    store i32 [[VALLOADED]], ptr [[INOUT2]], align 4
 ; OPT-NEXT:    ret void
+;
 
   %val = load i32, ptr %inout
   %less = icmp slt i32 %val, 0
@@ -500,6 +499,7 @@ define ptx_kernel void @grid_const_phi_ngc(ptr byval(%struct.s) align 4 %input1,
 ; OPT-NEXT:    [[VALLOADED:%.*]] = load i32, ptr [[PTRNEW]], align 4
 ; OPT-NEXT:    store i32 [[VALLOADED]], ptr [[INOUT2]], align 4
 ; OPT-NEXT:    ret void
+;
   %val = load i32, ptr %inout
   %less = icmp slt i32 %val, 0
   br i1 %less, label %first, label %second
@@ -553,6 +553,7 @@ define ptx_kernel void @grid_const_select(ptr byval(i32) align 4 %input1, ptr by
 ; OPT-NEXT:    [[VALLOADED:%.*]] = load i32, ptr [[PTRNEW]], align 4
 ; OPT-NEXT:    store i32 [[VALLOADED]], ptr [[INOUT2]], align 4
 ; OPT-NEXT:    ret void
+;
   %val = load i32, ptr %inout
   %less = icmp slt i32 %val, 0
   %ptrnew = select i1 %less, ptr %input1, ptr %input2
@@ -584,6 +585,7 @@ define ptx_kernel i32 @grid_const_ptrtoint(ptr byval(i32) %input) {
 ; OPT-NEXT:    [[PTRVAL:%.*]] = ptrtoint ptr [[INPUT1]] to i32
 ; OPT-NEXT:    [[KEEPALIVE:%.*]] = add i32 [[INPUT3]], [[PTRVAL]]
 ; OPT-NEXT:    ret i32 [[KEEPALIVE]]
+;
   %val = load i32, ptr %input
   %ptrval = ptrtoint ptr %input to i32
   %keepalive = add i32 %val, %ptrval
diff --git a/llvm/test/CodeGen/NVPTX/lower-args.ll b/llvm/test/CodeGen/NVPTX/lower-args.ll
index 66bd5e52b5f11..42024a5be1c08 100644
--- a/llvm/test/CodeGen/NVPTX/lower-args.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args.ll
@@ -17,9 +17,8 @@ define void @load_alignment(ptr nocapture readonly byval(%class.outer) align 8 %
 ; IR-LABEL: define void @load_alignment(
 ; IR-SAME: ptr readonly byval([[CLASS_OUTER:%.*]]) align 8 captures(none) [[ARG:%.*]]) {
 ; IR-NEXT:  [[ENTRY:.*:]]
-; IR-NEXT:    [[ARG1:%.*]] = alloca [[CLASS_OUTER]], align 8
-; IR-NEXT:    [[ARG2:%.*]] = addrspacecast ptr [[ARG]] to ptr addrspace(101)
-; IR-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 8 [[ARG1]], ptr addrspace(101) align 8 [[ARG2]], i64 24, i1 false)
+; IR-NEXT:    [[ARG2:%.*]] = addrspacecast ptr [[ARG]] to ptr addrspace(5)
+; IR-NEXT:    [[ARG1:%.*]] = addrspacecast ptr addrspace(5) [[ARG2]] to ptr
 ; IR-NEXT:    [[ARG_IDX_VAL:%.*]] = load ptr, ptr [[ARG1]], align 8
 ; IR-NEXT:    [[ARG_IDX1:%.*]] = getelementptr [[CLASS_OUTER]], ptr [[ARG1]], i64 0, i32 0, i32 1
 ; IR-NEXT:    [[ARG_IDX1_VAL:%.*]] = load ptr, ptr [[ARG_IDX1]], align 8
@@ -33,27 +32,21 @@ define void @load_alignment(ptr nocapture readonly byval(%class.outer) align 8 %
 ;
 ; PTX-LABEL: load_alignment(
 ; PTX:       {
-; PTX-NEXT:    .local .align 8 .b8 __local_depot0[24];
-; PTX-NEXT:    .reg .b64 %SP;
-; PTX-NEXT:    .reg .b64 %SPL;
 ; PTX-NEXT:    .reg .b32 %r<4>;
 ; PTX-NEXT:    .reg .b64 %rd<10>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0: // %entry
-; PTX-NEXT:    mov.u64 %SPL, __local_depot0;
-; PTX-NEXT:    add.u64 %rd2, %SPL, 0;
-; PTX-NEXT:    ld.param.u64 %rd3, [load_alignment_param_0+16];
-; PTX-NEXT:    st.local.u64 [%rd2+16], %rd3;
-; PTX-NEXT:    ld.param.u64 %rd4, [load_alignment_param_0+8];
-; PTX-NEXT:    st.local.u64 [%rd2+8], %rd4;
-; PTX-NEXT:    ld.param.u64 %rd5, [load_alignment_param_0];
-; PTX-NEXT:    st.local.u64 [%rd2], %rd5;
-; PTX-NEXT:    add.s64 %rd6, %rd2, 16;
+; PTX-NEXT:    mov.b64 %rd1, load_alignment_param_0;
+; PTX-NEXT:    cvta.local.u64 %rd2, %rd1;
+; PTX-NEXT:    cvta.to.local.u64 %rd3, %rd2;
+; PTX-NEXT:    ld.local.u64 %rd4, [%rd3];
+; PTX-NEXT:    ld.local.u64 %rd5, [%rd3+8];
+; PTX-NEXT:    add.s64 %rd6, %rd3, 16;
 ; PTX-NEXT:    cvta.local.u64 %rd7, %rd6;
-; PTX-NEXT:    cvt.u32.u64 %r1, %rd3;
-; PTX-NEXT:    ld.u32 %r2, [%rd5];
+; PTX-NEXT:    ld.local.u32 %r1, [%rd3+16];
+; PTX-NEXT:    ld.u32 %r2, [%rd4];
 ; PTX-NEXT:    add.s32 %r3, %r2, %r1;
-; PTX-NEXT:    st.u32 [%rd4], %r3;
+; PTX-NEXT:    st.u32 [%rd5], %r3;
 ; PTX-NEXT:    { // callseq 0, 0
 ; PTX-NEXT:    .param .b64 param0;
 ; PTX-NEXT:    st.param.b64 [param0], %rd7;
@@ -85,37 +78,29 @@ entry:
-; Check that nvptx-lower-args copies padding as the struct may have been a union
+; Check that nvptx-lower-args preserves padding as the struct may have been a union
 define void @load_padding(ptr nocapture readonly byval(%class.padded) %arg) {
 ; IR-LABEL: define void @load_padding(
-; IR-SAME: ptr readonly byval([[CLASS_PADDED:%.*]]) captures(none) [[ARG:%.*]]) {
-; IR-NEXT:    [[ARG1:%.*]] = alloca [[CLASS_PADDED]], align 8
-; IR-NEXT:    [[ARG2:%.*]] = addrspacecast ptr [[ARG]] to ptr addrspace(101)
-; IR-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 8 [[ARG1]], ptr addrspace(101) align 8 [[ARG2]], i64 8, i1 false)
+; IR-SAME: ptr readonly byval([[CLASS_PADDED:%.*]]) align 4 captures(none) [[ARG:%.*]]) {
+; IR-NEXT:    [[ARG2:%.*]] = addrspacecast ptr [[ARG]] to ptr addrspace(5)
+; IR-NEXT:    [[ARG1:%.*]] = addrspacecast ptr addrspace(5) [[ARG2]] to ptr
 ; IR-NEXT:    [[TMP:%.*]] = call ptr @escape(ptr nonnull align 16 [[ARG1]])
 ; IR-NEXT:    ret void
 ;
 ; PTX-LABEL: load_padding(
 ; PTX:       {
-; PTX-NEXT:    .local .align 8 .b8 __local_depot1[8];
-; PTX-NEXT:    .reg .b64 %SP;
-; PTX-NEXT:    .reg .b64 %SPL;
-; PTX-NEXT:    .reg .b64 %rd<6>;
+; PTX-NEXT:    .reg .b64 %rd<5>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0:
-; PTX-NEXT:    mov.u64 %SPL, __local_depot1;
-; PTX-NEXT:    cvta.local.u64 %SP, %SPL;
-; PTX-NEXT:    add.u64 %rd1, %SP, 0;
-; PTX-NEXT:    add.u64 %rd2, %SPL, 0;
-; PTX-NEXT:    ld.param.u64 %rd3, [load_padding_param_0];
-; PTX-NEXT:    st.local.u64 [%rd2], %rd3;
+; PTX-NEXT:    mov.b64 %rd1, load_padding_param_0;
+; PTX-NEXT:    cvta.local.u64 %rd2, %rd1;
 ; PTX-NEXT:    { // callseq 1, 0
 ; PTX-NEXT:    .param .b64 param0;
-; PTX-NEXT:    st.param.b64 [param0], %rd1;
+; PTX-NEXT:    st.param.b64 [param0], %rd2;
 ; PTX-NEXT:    .param .b64 retval0;
 ; PTX-NEXT:    call.uni (retval0),
 ; PTX-NEXT:    escape,
 ; PTX-NEXT:    (
 ; PTX-NEXT:    param0
 ; PTX-NEXT:    );
-; PTX-NEXT:    ld.param.b64 %rd4, [retval0];
+; PTX-NEXT:    ld.param.b64 %rd3, [retval0];
 ; PTX-NEXT:    } // callseq 1
 ; PTX-NEXT:    ret;
   %tmp = call ptr @escape(ptr nonnull align 16 %arg)
diff --git a/llvm/test/CodeGen/NVPTX/variadics-backend.ll b/llvm/test/CodeGen/NVPTX/variadics-backend.ll
index 377528b94f505..eaf0ce58750b4 100644
--- a/llvm/test/CodeGen/NVPTX/variadics-backend.ll
+++ b/llvm/test/CodeGen/NVPTX/variadics-backend.ll
@@ -338,18 +338,18 @@ define dso_local i32 @variadics4(ptr noundef byval(%struct.S2) align 8 %first, .
 ; CHECK-PTX-LABEL: variadics4(
 ; CHECK-PTX:       {
 ; CHECK-PTX-NEXT:    .reg .b32 %r<2>;
-; CHECK-PTX-NEXT:    .reg .b64 %rd<9>;
+; CHECK-PTX-NEXT:    .reg .b64 %rd<12>;
 ; CHECK-PTX-EMPTY:
 ; CHECK-PTX-NEXT:  // %bb.0: // %entry
-; CHECK-PTX-NEXT:    ld.param.u64 %rd1, [variadics4_param_1];
-; CHECK-PTX-NEXT:    add.s64 %rd2, %rd1, 7;
-; CHECK-PTX-NEXT:    and.b64 %rd3, %rd2, -8;
-; CHECK-PTX-NEXT:    ld.u64 %rd4, [%rd3];
-; CHECK-PTX-NEXT:    ld.param.u64 %rd5, [variadics4_param_0];
-; CHECK-PTX-NEXT:    ld.param.u64 %rd6, [variadics4_param_0+8];
-; CHECK-PTX-NEXT:    add.s64 %rd7, %rd5, %rd6;
-; CHECK-PTX-NEXT:    add.s64 %rd8, %rd7, %rd4;
-; CHECK-PTX-NEXT:    cvt.u32.u64 %r1, %rd8;
+; CHECK-PTX-NEXT:    ld.param.u64 %rd4, [variadics4_param_1];
+; CHECK-PTX-NEXT:    add.s64 %rd5, %rd4, 7;
+; CHECK-PTX-NEXT:    and.b64 %rd6, %rd5, -8;
+; CHECK-PTX-NEXT:    ld.u64 %rd7, [%rd6];
+; CHECK-PTX-NEXT:    ld.param.u64 %rd8, [variadics4_param_0];
+; CHECK-PTX-NEXT:    ld.param.u64 %rd9, [variadics4_param_0+8];
+; CHECK-PTX-NEXT:    add.s64 %rd10, %rd8, %rd9;
+; CHECK-PTX-NEXT:    add.s64 %rd11, %rd10, %rd7;
+; CHECK-PTX-NEXT:    cvt.u32.u64 %r1, %rd11;
 ; CHECK-PTX-NEXT:    st.param.b32 [func_retval0], %r1;
 ; CHECK-PTX-NEXT:    ret;
 entry:
diff --git a/llvm/test/tools/UpdateTestChecks/update_llc_test_checks/Inputs/nvptx-basic.ll.expected b/llvm/test/tools/UpdateTestChecks/update_llc_test_checks/Inputs/nvptx-basic.ll.expected
index 820ade631dd64..e982758da5b06 100644
--- a/llvm/test/tools/UpdateTestChecks/update_llc_test_checks/Inputs/nvptx-basic.ll.expected
+++ b/llvm/test/tools/UpdateTestChecks/update_llc_test_checks/Inputs/nvptx-basic.ll.expected
@@ -6,28 +6,18 @@
 define dso_local void @caller_St8x4(ptr nocapture noundef readonly byval(%struct.St8x4) align 8 %in, ptr nocapture noundef writeonly %ret) {
 ; CHECK-LABEL: caller_St8x4(
 ; CHECK:       {
-; CHECK-NEXT:    .local .align 8 .b8 __local_depot0[32];
-; CHECK-NEXT:    .reg .b32 %SP;
-; CHECK-NEXT:    .reg .b32 %SPL;
 ; CHECK-NEXT:    .reg .b32 %r<4>;
 ; CHECK-NEXT:    .reg .b64 %rd<13>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    mov.u32 %SPL, __local_depot0;
-; CHECK-NEXT:    ld.param.u32 %r1, [caller_St8x4_param_1];
-; CHECK-NEXT:    add.u32 %r3, %SPL, 0;
-; CHECK-NEXT:    ld.param.u64 %rd1, [caller_St8x4_param_0+24];
-; CHECK-NEXT:    st.local.u64 [%r3+24], %rd1;
-; CHECK-NEXT:    ld.param.u64 %rd2, [caller_St8x4_param_0+16];
-; CHECK-NEXT:    st.local.u64 [%r3+16], %rd2;
-; CHECK-NEXT:    ld.param.u64 %rd3, [caller_St8x4_param_0+8];
-; CHECK-NEXT:    st.local.u64 [%r3+8], %rd3;
-; CHECK-NEXT:    ld.param.u64 %rd4, [caller_St8x4_param_0];
-; CHECK-NEXT:    st.local.u64 [%r3], %rd4;
+; CHECK-NEXT:    ld.param.u64 %rd1, [caller_St8x4_param_0+8];
+; CHECK-NEXT:    ld.param.u64 %rd2, [caller_St8x4_param_0];
+; CHECK-NEXT:    ld.param.u64 %rd3, [caller_St8x4_param_0+24];
+; CHECK-NEXT:    ld.param.u64 %rd4, [caller_St8x4_param_0+16];
 ; CHECK-NEXT:    { // callseq 0, 0
 ; CHECK-NEXT:    .param .align 16 .b8 param0[32];
-; CHECK-NEXT:    st.param.v2.b64 [param0], {%rd4, %rd3};
-; CHECK-NEXT:    st.param.v2.b64 [param0+16], {%rd2, %rd1};
+; CHECK-NEXT:    st.param.v2.b64 [param0], {%rd2, %rd1};
+; CHECK-NEXT:    st.param.v2.b64 [param0+16], {%rd4, %rd3};
 ; CHECK-NEXT:    .param .align 16 .b8 retval0[32];
 ; CHECK-NEXT:    call.uni (retval0),
 ; CHECK-NEXT:    callee_St8x4,
@@ -37,10 +27,11 @@ define dso_local void @caller_St8x4(ptr nocapture noundef readonly byval(%struct
 ; CHECK-NEXT:    ld.param.v2.b64 {%rd5, %rd6}, [retval0];
 ; CHECK-NEXT:    ld.param.v2.b64 {%rd7, %rd8}, [retval0+16];
 ; CHECK-NEXT:    } // callseq 0
-; CHECK-NEXT:    st.u64 [%r1], %rd5;
-; CHECK-NEXT:    st.u64 [%r1+8], %rd6;
-; CHECK-NEXT:    st.u64 [%r1+16], %rd7;
-; CHECK-NEXT:    st.u64 [%r1+24], %rd8;
+; CHECK-NEXT:    ld.param.u32 %r3, [caller_St8x4_param_1];
+; CHECK-NEXT:    st.u64 [%r3], %rd5;
+; CHECK-NEXT:    st.u64 [%r3+8], %rd6;
+; CHECK-NEXT:    st.u64 [%r3+16], %rd7;
+; CHECK-NEXT:    st.u64 [%r3+24], %rd8;
 ; CHECK-NEXT:    ret;
   %call = tail call fastcc [4 x i64] @callee_St8x4(ptr noundef nonnull byval(%struct.St8x4) align 8 %in) #2
   %.fca.0.extract = extractvalue [4 x i64] %call, 0
@@ -61,6 +52,7 @@ define internal fastcc [4 x i64] @callee_St8x4(ptr nocapture noundef readonly by
 ; CHECK-LABEL: callee_St8x4(
 ; CHECK:         // @callee_St8x4
 ; CHECK-NEXT:  {
+; CHECK-NEXT:    .reg .b32 %r<4>;
 ; CHECK-NEXT:    .reg .b64 %rd<5>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:

>From ea0147c71b8ced6a1c2321e482250a30a7abeeff Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Fri, 28 Feb 2025 20:47:10 +0000
Subject: [PATCH 3/3] rebase + address comments

---
 llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp      |  6 ++---
 llvm/test/CodeGen/NVPTX/forward-ld-param.ll   | 20 +++++++----------
 llvm/test/CodeGen/NVPTX/i128-array.ll         | 14 ++++++------
 llvm/test/CodeGen/NVPTX/lower-args.ll         | 22 +++++++++----------
 llvm/test/CodeGen/NVPTX/variadics-backend.ll  | 20 ++++++++---------
 .../Inputs/nvptx-basic.ll.expected            |  2 +-
 6 files changed, 39 insertions(+), 45 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index 5161a682fb01a..6dc927774eff8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -409,11 +409,11 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
     Worklist.pop();
 
     for (User *CurUser : Ctx.InitialVal->users()) {
-      if (auto *I = dyn_cast<LoadInst>(CurUser)) {
+      if (auto *I = dyn_cast<LoadInst>(CurUser))
         Loads.push_back({I, Ctx.Offset});
-      } else if (isa<BitCastInst>(CurUser) || isa<AddrSpaceCastInst>(CurUser)) {
+      else if (isa<BitCastInst>(CurUser) || isa<AddrSpaceCastInst>(CurUser))
         Worklist.push({cast<Instruction>(CurUser), Ctx.Offset});
-      } else if (auto *I = dyn_cast<GetElementPtrInst>(CurUser)) {
+      else if (auto *I = dyn_cast<GetElementPtrInst>(CurUser)) {
         APInt OffsetAccumulated =
             APInt::getZero(DL.getIndexSizeInBits(ADDRESS_SPACE_PARAM));
 
diff --git a/llvm/test/CodeGen/NVPTX/forward-ld-param.ll b/llvm/test/CodeGen/NVPTX/forward-ld-param.ll
index c4e56d197edc0..5bf2a84b0013a 100644
--- a/llvm/test/CodeGen/NVPTX/forward-ld-param.ll
+++ b/llvm/test/CodeGen/NVPTX/forward-ld-param.ll
@@ -7,7 +7,7 @@ define i32 @test_ld_param_const(ptr byval(i32) %a) {
 ; CHECK-LABEL: test_ld_param_const(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .b32 %r<2>;
-; CHECK-NEXT:    .reg .b64 %rd<4>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.u32 %r1, [test_ld_param_const_param_0+4];
@@ -22,15 +22,13 @@ define i32 @test_ld_param_non_const(ptr byval([10 x i32]) %a, i32 %b) {
 ; CHECK-LABEL: test_ld_param_non_const(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .b32 %r<2>;
-; CHECK-NEXT:    .reg .b64 %rd<6>;
+; CHECK-NEXT:    .reg .b64 %rd<4>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    mov.b64 %rd1, test_ld_param_non_const_param_0;
-; CHECK-NEXT:    cvta.local.u64 %rd2, %rd1;
-; CHECK-NEXT:    cvta.to.local.u64 %rd3, %rd2;
-; CHECK-NEXT:    ld.param.s32 %rd4, [test_ld_param_non_const_param_1];
-; CHECK-NEXT:    add.s64 %rd5, %rd3, %rd4;
-; CHECK-NEXT:    ld.local.u32 %r1, [%rd5];
+; CHECK-NEXT:    ld.param.s32 %rd2, [test_ld_param_non_const_param_1];
+; CHECK-NEXT:    add.s64 %rd3, %rd1, %rd2;
+; CHECK-NEXT:    ld.local.u32 %r1, [%rd3];
 ; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
 ; CHECK-NEXT:    ret;
   %p2 = getelementptr i8, ptr %a, i32 %b
@@ -89,15 +87,13 @@ define i32 @test_modify_param(ptr byval([10 x i32]) %a, i32 %b, i32 %c ) {
 ; CHECK-LABEL: test_modify_param(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .b32 %r<3>;
-; CHECK-NEXT:    .reg .b64 %rd<4>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    mov.b64 %rd1, test_modify_param_param_0;
-; CHECK-NEXT:    cvta.local.u64 %rd2, %rd1;
-; CHECK-NEXT:    cvta.to.local.u64 %rd3, %rd2;
 ; CHECK-NEXT:    ld.param.u32 %r1, [test_modify_param_param_1];
 ; CHECK-NEXT:    ld.param.u32 %r2, [test_modify_param_param_2];
-; CHECK-NEXT:    st.local.u32 [%rd3+2], %r1;
+; CHECK-NEXT:    st.local.u32 [%rd1+2], %r1;
 ; CHECK-NEXT:    st.param.b32 [func_retval0], %r2;
 ; CHECK-NEXT:    ret;
   %p2 = getelementptr i8, ptr %a, i32 2
@@ -111,7 +107,7 @@ define i32 @test_multi_block(ptr byval([10 x i32]) %a, i1 %p) {
 ; CHECK-NEXT:    .reg .pred %p<3>;
 ; CHECK-NEXT:    .reg .b16 %rs<3>;
 ; CHECK-NEXT:    .reg .b32 %r<5>;
-; CHECK-NEXT:    .reg .b64 %rd<4>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.u8 %rs1, [test_multi_block_param_1];
diff --git a/llvm/test/CodeGen/NVPTX/i128-array.ll b/llvm/test/CodeGen/NVPTX/i128-array.ll
index baa18880de840..fb69224e87d11 100644
--- a/llvm/test/CodeGen/NVPTX/i128-array.ll
+++ b/llvm/test/CodeGen/NVPTX/i128-array.ll
@@ -27,15 +27,15 @@ define [2 x i128] @foo(i64 %a, i32 %b) {
 define [2 x i128] @foo2(ptr byval([2 x i128]) %a) {
 ; CHECK-LABEL: foo2(
 ; CHECK:       {
-; CHECK-NEXT:    .reg .b64 %rd<9>;
+; CHECK-NEXT:    .reg .b64 %rd<7>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.u64 %rd5, [foo2_param_0+8];
-; CHECK-NEXT:    ld.param.u64 %rd6, [foo2_param_0];
-; CHECK-NEXT:    ld.param.u64 %rd7, [foo2_param_0+24];
-; CHECK-NEXT:    ld.param.u64 %rd8, [foo2_param_0+16];
-; CHECK-NEXT:    st.param.v2.b64 [func_retval0], {%rd6, %rd5};
-; CHECK-NEXT:    st.param.v2.b64 [func_retval0+16], {%rd8, %rd7};
+; CHECK-NEXT:    ld.param.u64 %rd3, [foo2_param_0+8];
+; CHECK-NEXT:    ld.param.u64 %rd4, [foo2_param_0];
+; CHECK-NEXT:    ld.param.u64 %rd5, [foo2_param_0+24];
+; CHECK-NEXT:    ld.param.u64 %rd6, [foo2_param_0+16];
+; CHECK-NEXT:    st.param.v2.b64 [func_retval0], {%rd4, %rd3};
+; CHECK-NEXT:    st.param.v2.b64 [func_retval0+16], {%rd6, %rd5};
 ; CHECK-NEXT:    ret;
   %ptr0 = getelementptr [2 x i128], ptr %a, i64 0, i32 0
   %1 = load i128, i128* %ptr0
diff --git a/llvm/test/CodeGen/NVPTX/lower-args.ll b/llvm/test/CodeGen/NVPTX/lower-args.ll
index 42024a5be1c08..a1c0a86e9c4e4 100644
--- a/llvm/test/CodeGen/NVPTX/lower-args.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args.ll
@@ -33,30 +33,28 @@ define void @load_alignment(ptr nocapture readonly byval(%class.outer) align 8 %
 ; PTX-LABEL: load_alignment(
 ; PTX:       {
 ; PTX-NEXT:    .reg .b32 %r<4>;
-; PTX-NEXT:    .reg .b64 %rd<10>;
+; PTX-NEXT:    .reg .b64 %rd<8>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0: // %entry
 ; PTX-NEXT:    mov.b64 %rd1, load_alignment_param_0;
-; PTX-NEXT:    cvta.local.u64 %rd2, %rd1;
-; PTX-NEXT:    cvta.to.local.u64 %rd3, %rd2;
-; PTX-NEXT:    ld.local.u64 %rd4, [%rd3];
-; PTX-NEXT:    ld.local.u64 %rd5, [%rd3+8];
-; PTX-NEXT:    add.s64 %rd6, %rd3, 16;
-; PTX-NEXT:    cvta.local.u64 %rd7, %rd6;
-; PTX-NEXT:    ld.local.u32 %r1, [%rd3+16];
-; PTX-NEXT:    ld.u32 %r2, [%rd4];
+; PTX-NEXT:    ld.local.u64 %rd2, [%rd1];
+; PTX-NEXT:    ld.local.u64 %rd3, [%rd1+8];
+; PTX-NEXT:    add.s64 %rd4, %rd1, 16;
+; PTX-NEXT:    cvta.local.u64 %rd5, %rd4;
+; PTX-NEXT:    ld.local.u32 %r1, [%rd1+16];
+; PTX-NEXT:    ld.u32 %r2, [%rd2];
 ; PTX-NEXT:    add.s32 %r3, %r2, %r1;
-; PTX-NEXT:    st.u32 [%rd5], %r3;
+; PTX-NEXT:    st.u32 [%rd3], %r3;
 ; PTX-NEXT:    { // callseq 0, 0
 ; PTX-NEXT:    .param .b64 param0;
-; PTX-NEXT:    st.param.b64 [param0], %rd7;
+; PTX-NEXT:    st.param.b64 [param0], %rd5;
 ; PTX-NEXT:    .param .b64 retval0;
 ; PTX-NEXT:    call.uni (retval0),
 ; PTX-NEXT:    escape,
 ; PTX-NEXT:    (
 ; PTX-NEXT:    param0
 ; PTX-NEXT:    );
-; PTX-NEXT:    ld.param.b64 %rd8, [retval0];
+; PTX-NEXT:    ld.param.b64 %rd6, [retval0];
 ; PTX-NEXT:    } // callseq 0
 ; PTX-NEXT:    ret;
 entry:
diff --git a/llvm/test/CodeGen/NVPTX/variadics-backend.ll b/llvm/test/CodeGen/NVPTX/variadics-backend.ll
index eaf0ce58750b4..5a7e40ce898df 100644
--- a/llvm/test/CodeGen/NVPTX/variadics-backend.ll
+++ b/llvm/test/CodeGen/NVPTX/variadics-backend.ll
@@ -338,18 +338,18 @@ define dso_local i32 @variadics4(ptr noundef byval(%struct.S2) align 8 %first, .
 ; CHECK-PTX-LABEL: variadics4(
 ; CHECK-PTX:       {
 ; CHECK-PTX-NEXT:    .reg .b32 %r<2>;
-; CHECK-PTX-NEXT:    .reg .b64 %rd<12>;
+; CHECK-PTX-NEXT:    .reg .b64 %rd<10>;
 ; CHECK-PTX-EMPTY:
 ; CHECK-PTX-NEXT:  // %bb.0: // %entry
-; CHECK-PTX-NEXT:    ld.param.u64 %rd4, [variadics4_param_1];
-; CHECK-PTX-NEXT:    add.s64 %rd5, %rd4, 7;
-; CHECK-PTX-NEXT:    and.b64 %rd6, %rd5, -8;
-; CHECK-PTX-NEXT:    ld.u64 %rd7, [%rd6];
-; CHECK-PTX-NEXT:    ld.param.u64 %rd8, [variadics4_param_0];
-; CHECK-PTX-NEXT:    ld.param.u64 %rd9, [variadics4_param_0+8];
-; CHECK-PTX-NEXT:    add.s64 %rd10, %rd8, %rd9;
-; CHECK-PTX-NEXT:    add.s64 %rd11, %rd10, %rd7;
-; CHECK-PTX-NEXT:    cvt.u32.u64 %r1, %rd11;
+; CHECK-PTX-NEXT:    ld.param.u64 %rd2, [variadics4_param_1];
+; CHECK-PTX-NEXT:    add.s64 %rd3, %rd2, 7;
+; CHECK-PTX-NEXT:    and.b64 %rd4, %rd3, -8;
+; CHECK-PTX-NEXT:    ld.u64 %rd5, [%rd4];
+; CHECK-PTX-NEXT:    ld.param.u64 %rd6, [variadics4_param_0];
+; CHECK-PTX-NEXT:    ld.param.u64 %rd7, [variadics4_param_0+8];
+; CHECK-PTX-NEXT:    add.s64 %rd8, %rd6, %rd7;
+; CHECK-PTX-NEXT:    add.s64 %rd9, %rd8, %rd5;
+; CHECK-PTX-NEXT:    cvt.u32.u64 %r1, %rd9;
 ; CHECK-PTX-NEXT:    st.param.b32 [func_retval0], %r1;
 ; CHECK-PTX-NEXT:    ret;
 entry:
diff --git a/llvm/test/tools/UpdateTestChecks/update_llc_test_checks/Inputs/nvptx-basic.ll.expected b/llvm/test/tools/UpdateTestChecks/update_llc_test_checks/Inputs/nvptx-basic.ll.expected
index e982758da5b06..e470569bfae19 100644
--- a/llvm/test/tools/UpdateTestChecks/update_llc_test_checks/Inputs/nvptx-basic.ll.expected
+++ b/llvm/test/tools/UpdateTestChecks/update_llc_test_checks/Inputs/nvptx-basic.ll.expected
@@ -52,7 +52,7 @@ define internal fastcc [4 x i64] @callee_St8x4(ptr nocapture noundef readonly by
 ; CHECK-LABEL: callee_St8x4(
 ; CHECK:         // @callee_St8x4
 ; CHECK-NEXT:  {
-; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
 ; CHECK-NEXT:    .reg .b64 %rd<5>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:



More information about the llvm-commits mailing list