[llvm] [NVPTX][InferAS] assume alloca instructions are in local AS (PR #121710)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 31 12:09:19 PST 2025


https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/121710

>From 544b7d1ff7575372a43d8de40abdf2286fd081b0 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Mon, 13 Jan 2025 00:35:34 +0000
Subject: [PATCH 1/4] scratch

---
 .../Target/NVPTX/NVPTXTargetTransformInfo.cpp |   8 ++
 .../Target/NVPTX/NVPTXTargetTransformInfo.h   |   1 +
 llvm/test/CodeGen/NVPTX/local-stack-frame.ll  |   8 +-
 .../CodeGen/NVPTX/lower-args-gridconstant.ll  |  31 +++---
 llvm/test/CodeGen/NVPTX/lower-args.ll         |  13 ++-
 llvm/test/CodeGen/NVPTX/variadics-backend.ll  | 104 +++++++++---------
 .../InferAddressSpaces/NVPTX/alloca.ll        |  17 +++
 .../Inputs/nvptx-basic.ll.expected            |  33 +++---
 8 files changed, 125 insertions(+), 90 deletions(-)
 create mode 100644 llvm/test/Transforms/InferAddressSpaces/NVPTX/alloca.ll

diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index 85e99d7fe97a26d..e216f09c02d92ce 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -20,6 +20,7 @@
 #include "llvm/IR/Value.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/NVPTXAddrSpace.h"
 #include "llvm/Transforms/InstCombine/InstCombiner.h"
 #include <optional>
 using namespace llvm;
@@ -564,6 +565,13 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
   return nullptr;
 }
 
+unsigned NVPTXTTIImpl::getAssumedAddrSpace(const Value *V) const {
+  if (isa<AllocaInst>(V)) // an alloca is assumed to live in the local AS
+    return ADDRESS_SPACE_LOCAL;
+
+  return -1; // wraps to UINT_MAX; presumably "no assumed AS" -- TODO confirm
+}
+
 void NVPTXTTIImpl::collectKernelLaunchBounds(
     const Function &F,
     SmallVectorImpl<std::pair<StringRef, int64_t>> &LB) const {
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index b0a846a9c7f960e..7f69d422e8b4b4c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -129,6 +129,7 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {
 
   Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
                                           Value *NewV) const;
+  unsigned getAssumedAddrSpace(const Value *V) const;
 
   void collectKernelLaunchBounds(
       const Function &F,
diff --git a/llvm/test/CodeGen/NVPTX/local-stack-frame.ll b/llvm/test/CodeGen/NVPTX/local-stack-frame.ll
index f21ff974a2c6bb6..7202e20628fe735 100644
--- a/llvm/test/CodeGen/NVPTX/local-stack-frame.ll
+++ b/llvm/test/CodeGen/NVPTX/local-stack-frame.ll
@@ -6,13 +6,13 @@
 ; Ensure we access the local stack properly
 
 ; PTX32:        mov.u32          %SPL, __local_depot{{[0-9]+}};
-; PTX32:        cvta.local.u32   %SP, %SPL;
 ; PTX32:        ld.param.u32     %r{{[0-9]+}}, [foo_param_0];
-; PTX32:        st.volatile.u32  [%SP], %r{{[0-9]+}};
+; PTX32:        add.u32          %r[[SP_REG:[0-9]+]], %SPL, 0;
+; PTX32:        st.local.u32  [%r[[SP_REG]]], %r{{[0-9]+}};
 ; PTX64:        mov.u64          %SPL, __local_depot{{[0-9]+}};
-; PTX64:        cvta.local.u64   %SP, %SPL;
 ; PTX64:        ld.param.u32     %r{{[0-9]+}}, [foo_param_0];
-; PTX64:        st.volatile.u32  [%SP], %r{{[0-9]+}};
+; PTX64:        add.u64          %rd[[SP_REG:[0-9]+]], %SPL, 0;
+; PTX64:        st.local.u32  [%rd[[SP_REG]]], %r{{[0-9]+}};
 define void @foo(i32 %a) {
   %local = alloca i32, align 4
   store volatile i32 %a, ptr %local
diff --git a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
index 28be5d7adbf8a00..fe15be5663be195 100644
--- a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
@@ -29,7 +29,7 @@ define dso_local noundef i32 @non_kernel_function(ptr nocapture noundef readonly
 ; PTX-NEXT:    .reg .pred %p<2>;
 ; PTX-NEXT:    .reg .b16 %rs<3>;
 ; PTX-NEXT:    .reg .b32 %r<11>;
-; PTX-NEXT:    .reg .b64 %rd<9>;
+; PTX-NEXT:    .reg .b64 %rd<10>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0: // %entry
 ; PTX-NEXT:    mov.u64 %SPL, __local_depot0;
@@ -37,23 +37,24 @@ define dso_local noundef i32 @non_kernel_function(ptr nocapture noundef readonly
 ; PTX-NEXT:    ld.param.u8 %rs1, [non_kernel_function_param_1];
 ; PTX-NEXT:    and.b16 %rs2, %rs1, 1;
 ; PTX-NEXT:    setp.eq.b16 %p1, %rs2, 1;
-; PTX-NEXT:    ld.param.s32 %rd1, [non_kernel_function_param_2];
-; PTX-NEXT:    ld.param.u64 %rd2, [non_kernel_function_param_0+8];
-; PTX-NEXT:    st.u64 [%SP+8], %rd2;
-; PTX-NEXT:    ld.param.u64 %rd3, [non_kernel_function_param_0];
-; PTX-NEXT:    st.u64 [%SP], %rd3;
-; PTX-NEXT:    mov.u64 %rd4, gi;
-; PTX-NEXT:    cvta.global.u64 %rd5, %rd4;
-; PTX-NEXT:    add.u64 %rd6, %SP, 0;
-; PTX-NEXT:    selp.b64 %rd7, %rd6, %rd5, %p1;
-; PTX-NEXT:    add.s64 %rd8, %rd7, %rd1;
-; PTX-NEXT:    ld.u8 %r1, [%rd8];
-; PTX-NEXT:    ld.u8 %r2, [%rd8+1];
+; PTX-NEXT:    add.u64 %rd1, %SP, 0;
+; PTX-NEXT:    add.u64 %rd2, %SPL, 0;
+; PTX-NEXT:    ld.param.s32 %rd3, [non_kernel_function_param_2];
+; PTX-NEXT:    ld.param.u64 %rd4, [non_kernel_function_param_0+8];
+; PTX-NEXT:    st.local.u64 [%rd2+8], %rd4;
+; PTX-NEXT:    ld.param.u64 %rd5, [non_kernel_function_param_0];
+; PTX-NEXT:    st.local.u64 [%rd2], %rd5;
+; PTX-NEXT:    mov.u64 %rd6, gi;
+; PTX-NEXT:    cvta.global.u64 %rd7, %rd6;
+; PTX-NEXT:    selp.b64 %rd8, %rd1, %rd7, %p1;
+; PTX-NEXT:    add.s64 %rd9, %rd8, %rd3;
+; PTX-NEXT:    ld.u8 %r1, [%rd9];
+; PTX-NEXT:    ld.u8 %r2, [%rd9+1];
 ; PTX-NEXT:    shl.b32 %r3, %r2, 8;
 ; PTX-NEXT:    or.b32 %r4, %r3, %r1;
-; PTX-NEXT:    ld.u8 %r5, [%rd8+2];
+; PTX-NEXT:    ld.u8 %r5, [%rd9+2];
 ; PTX-NEXT:    shl.b32 %r6, %r5, 16;
-; PTX-NEXT:    ld.u8 %r7, [%rd8+3];
+; PTX-NEXT:    ld.u8 %r7, [%rd9+3];
 ; PTX-NEXT:    shl.b32 %r8, %r7, 24;
 ; PTX-NEXT:    or.b32 %r9, %r8, %r6;
 ; PTX-NEXT:    or.b32 %r10, %r9, %r4;
diff --git a/llvm/test/CodeGen/NVPTX/lower-args.ll b/llvm/test/CodeGen/NVPTX/lower-args.ll
index 269bba75dc5fb33..06d8caeda69e152 100644
--- a/llvm/test/CodeGen/NVPTX/lower-args.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args.ll
@@ -40,24 +40,25 @@ define void @load_padding(ptr nocapture readonly byval(%class.padded) %arg) {
 ; PTX-NEXT:    .local .align 8 .b8 __local_depot1[8];
 ; PTX-NEXT:    .reg .b64 %SP;
 ; PTX-NEXT:    .reg .b64 %SPL;
-; PTX-NEXT:    .reg .b64 %rd<5>;
+; PTX-NEXT:    .reg .b64 %rd<6>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0:
 ; PTX-NEXT:    mov.u64 %SPL, __local_depot1;
 ; PTX-NEXT:    cvta.local.u64 %SP, %SPL;
-; PTX-NEXT:    ld.param.u64 %rd1, [load_padding_param_0];
-; PTX-NEXT:    st.u64 [%SP], %rd1;
-; PTX-NEXT:    add.u64 %rd2, %SP, 0;
+; PTX-NEXT:    add.u64         %rd1, %SP, 0;
+; PTX-NEXT:    add.u64         %rd2, %SPL, 0;
+; PTX-NEXT:    ld.param.u64 %rd3, [load_padding_param_0];
+; PTX-NEXT:    st.local.u64 [%rd2], %rd3;
 ; PTX-NEXT:    { // callseq 1, 0
 ; PTX-NEXT:    .param .b64 param0;
-; PTX-NEXT:    st.param.b64 [param0], %rd2;
+; PTX-NEXT:    st.param.b64 [param0], %rd1;
 ; PTX-NEXT:    .param .b64 retval0;
 ; PTX-NEXT:    call.uni (retval0),
 ; PTX-NEXT:    escape,
 ; PTX-NEXT:    (
 ; PTX-NEXT:    param0
 ; PTX-NEXT:    );
-; PTX-NEXT:    ld.param.b64 %rd3, [retval0];
+; PTX-NEXT:    ld.param.b64 %rd4, [retval0];
 ; PTX-NEXT:    } // callseq 1
 ; PTX-NEXT:    ret;
   %tmp = call ptr @escape(ptr nonnull align 16 %arg)
diff --git a/llvm/test/CodeGen/NVPTX/variadics-backend.ll b/llvm/test/CodeGen/NVPTX/variadics-backend.ll
index f7ed690efabcf30..f5c1e238f553a54 100644
--- a/llvm/test/CodeGen/NVPTX/variadics-backend.ll
+++ b/llvm/test/CodeGen/NVPTX/variadics-backend.ll
@@ -148,35 +148,34 @@ entry:
 define dso_local i32 @variadics2(i32 noundef %first, ...) {
 ; CHECK-PTX-LABEL: variadics2(
 ; CHECK-PTX:       {
-; CHECK-PTX-NEXT:    .local .align 2 .b8 __local_depot2[4];
+; CHECK-PTX-NEXT:    .local .align 1 .b8 __local_depot2[3];
 ; CHECK-PTX-NEXT:    .reg .b64 %SP;
 ; CHECK-PTX-NEXT:    .reg .b64 %SPL;
-; CHECK-PTX-NEXT:    .reg .b16 %rs<6>;
+; CHECK-PTX-NEXT:    .reg .b16 %rs<4>;
 ; CHECK-PTX-NEXT:    .reg .b32 %r<7>;
-; CHECK-PTX-NEXT:    .reg .b64 %rd<7>;
+; CHECK-PTX-NEXT:    .reg .b64 %rd<9>;
 ; CHECK-PTX-EMPTY:
 ; CHECK-PTX-NEXT:  // %bb.0: // %entry
 ; CHECK-PTX-NEXT:    mov.u64 %SPL, __local_depot2;
-; CHECK-PTX-NEXT:    cvta.local.u64 %SP, %SPL;
 ; CHECK-PTX-NEXT:    ld.param.u32 %r1, [variadics2_param_0];
 ; CHECK-PTX-NEXT:    ld.param.u64 %rd1, [variadics2_param_1];
-; CHECK-PTX-NEXT:    add.s64 %rd2, %rd1, 7;
-; CHECK-PTX-NEXT:    and.b64 %rd3, %rd2, -8;
-; CHECK-PTX-NEXT:    ld.u32 %r2, [%rd3];
-; CHECK-PTX-NEXT:    ld.s8 %r3, [%rd3+4];
-; CHECK-PTX-NEXT:    ld.u8 %rs1, [%rd3+7];
-; CHECK-PTX-NEXT:    st.u8 [%SP+2], %rs1;
-; CHECK-PTX-NEXT:    ld.u8 %rs2, [%rd3+5];
-; CHECK-PTX-NEXT:    ld.u8 %rs3, [%rd3+6];
-; CHECK-PTX-NEXT:    shl.b16 %rs4, %rs3, 8;
-; CHECK-PTX-NEXT:    or.b16 %rs5, %rs4, %rs2;
-; CHECK-PTX-NEXT:    st.u16 [%SP], %rs5;
-; CHECK-PTX-NEXT:    ld.u64 %rd4, [%rd3+8];
+; CHECK-PTX-NEXT:    add.u64 %rd3, %SPL, 0;
+; CHECK-PTX-NEXT:    add.s64 %rd4, %rd1, 7;
+; CHECK-PTX-NEXT:    and.b64 %rd5, %rd4, -8;
+; CHECK-PTX-NEXT:    ld.u32 %r2, [%rd5];
+; CHECK-PTX-NEXT:    ld.s8 %r3, [%rd5+4];
+; CHECK-PTX-NEXT:    ld.u8 %rs1, [%rd5+7];
+; CHECK-PTX-NEXT:    st.local.u8 [%rd3+2], %rs1;
+; CHECK-PTX-NEXT:    ld.u8 %rs2, [%rd5+6];
+; CHECK-PTX-NEXT:    st.local.u8 [%rd3+1], %rs2;
+; CHECK-PTX-NEXT:    ld.u8 %rs3, [%rd5+5];
+; CHECK-PTX-NEXT:    st.local.u8 [%rd3], %rs3;
+; CHECK-PTX-NEXT:    ld.u64 %rd6, [%rd5+8];
 ; CHECK-PTX-NEXT:    add.s32 %r4, %r1, %r2;
 ; CHECK-PTX-NEXT:    add.s32 %r5, %r4, %r3;
-; CHECK-PTX-NEXT:    cvt.u64.u32 %rd5, %r5;
-; CHECK-PTX-NEXT:    add.s64 %rd6, %rd5, %rd4;
-; CHECK-PTX-NEXT:    cvt.u32.u64 %r6, %rd6;
+; CHECK-PTX-NEXT:    cvt.u64.u32 %rd7, %r5;
+; CHECK-PTX-NEXT:    add.s64 %rd8, %rd7, %rd6;
+; CHECK-PTX-NEXT:    cvt.u32.u64 %r6, %rd8;
 ; CHECK-PTX-NEXT:    st.param.b32 [func_retval0], %r6;
 ; CHECK-PTX-NEXT:    ret;
 entry:
@@ -213,39 +212,39 @@ define dso_local i32 @bar() {
 ; CHECK-PTX-NEXT:    .local .align 8 .b8 __local_depot3[24];
 ; CHECK-PTX-NEXT:    .reg .b64 %SP;
 ; CHECK-PTX-NEXT:    .reg .b64 %SPL;
-; CHECK-PTX-NEXT:    .reg .b16 %rs<10>;
+; CHECK-PTX-NEXT:    .reg .b16 %rs<8>;
 ; CHECK-PTX-NEXT:    .reg .b32 %r<4>;
-; CHECK-PTX-NEXT:    .reg .b64 %rd<7>;
+; CHECK-PTX-NEXT:    .reg .b64 %rd<9>;
 ; CHECK-PTX-EMPTY:
 ; CHECK-PTX-NEXT:  // %bb.0: // %entry
 ; CHECK-PTX-NEXT:    mov.u64 %SPL, __local_depot3;
 ; CHECK-PTX-NEXT:    cvta.local.u64 %SP, %SPL;
-; CHECK-PTX-NEXT:    mov.u64 %rd1, __const_$_bar_$_s1;
-; CHECK-PTX-NEXT:    add.s64 %rd2, %rd1, 7;
-; CHECK-PTX-NEXT:    ld.global.nc.u8 %rs1, [%rd2];
+; CHECK-PTX-NEXT:    add.u64 %rd2, %SPL, 0;
+; CHECK-PTX-NEXT:    mov.u64 %rd3, __const_$_bar_$_s1;
+; CHECK-PTX-NEXT:    add.s64 %rd4, %rd3, 7;
+; CHECK-PTX-NEXT:    ld.global.nc.u8 %rs1, [%rd4];
 ; CHECK-PTX-NEXT:    cvt.u16.u8 %rs2, %rs1;
-; CHECK-PTX-NEXT:    st.u8 [%SP+2], %rs2;
-; CHECK-PTX-NEXT:    add.s64 %rd3, %rd1, 5;
-; CHECK-PTX-NEXT:    ld.global.nc.u8 %rs3, [%rd3];
+; CHECK-PTX-NEXT:    st.local.u8 [%rd2+2], %rs2;
+; CHECK-PTX-NEXT:    add.s64 %rd5, %rd3, 6;
+; CHECK-PTX-NEXT:    ld.global.nc.u8 %rs3, [%rd5];
 ; CHECK-PTX-NEXT:    cvt.u16.u8 %rs4, %rs3;
-; CHECK-PTX-NEXT:    add.s64 %rd4, %rd1, 6;
-; CHECK-PTX-NEXT:    ld.global.nc.u8 %rs5, [%rd4];
+; CHECK-PTX-NEXT:    st.local.u8 [%rd2+1], %rs4;
+; CHECK-PTX-NEXT:    add.s64 %rd6, %rd3, 5;
+; CHECK-PTX-NEXT:    ld.global.nc.u8 %rs5, [%rd6];
 ; CHECK-PTX-NEXT:    cvt.u16.u8 %rs6, %rs5;
-; CHECK-PTX-NEXT:    shl.b16 %rs7, %rs6, 8;
-; CHECK-PTX-NEXT:    or.b16 %rs8, %rs7, %rs4;
-; CHECK-PTX-NEXT:    st.u16 [%SP], %rs8;
+; CHECK-PTX-NEXT:    st.local.u8 [%rd2], %rs6;
 ; CHECK-PTX-NEXT:    mov.b32 %r1, 1;
 ; CHECK-PTX-NEXT:    st.u32 [%SP+8], %r1;
-; CHECK-PTX-NEXT:    mov.b16 %rs9, 1;
-; CHECK-PTX-NEXT:    st.u8 [%SP+12], %rs9;
-; CHECK-PTX-NEXT:    mov.b64 %rd5, 1;
-; CHECK-PTX-NEXT:    st.u64 [%SP+16], %rd5;
-; CHECK-PTX-NEXT:    add.u64 %rd6, %SP, 8;
+; CHECK-PTX-NEXT:    mov.b16 %rs7, 1;
+; CHECK-PTX-NEXT:    st.u8 [%SP+12], %rs7;
+; CHECK-PTX-NEXT:    mov.b64 %rd7, 1;
+; CHECK-PTX-NEXT:    st.u64 [%SP+16], %rd7;
+; CHECK-PTX-NEXT:    add.u64 %rd8, %SP, 8;
 ; CHECK-PTX-NEXT:    { // callseq 1, 0
 ; CHECK-PTX-NEXT:    .param .b32 param0;
 ; CHECK-PTX-NEXT:    st.param.b32 [param0], 1;
 ; CHECK-PTX-NEXT:    .param .b64 param1;
-; CHECK-PTX-NEXT:    st.param.b64 [param1], %rd6;
+; CHECK-PTX-NEXT:    st.param.b64 [param1], %rd8;
 ; CHECK-PTX-NEXT:    .param .b32 retval0;
 ; CHECK-PTX-NEXT:    call.uni (retval0),
 ; CHECK-PTX-NEXT:    variadics2,
@@ -384,26 +383,29 @@ define dso_local void @qux() {
 ; CHECK-PTX-NEXT:    .reg .b64 %SP;
 ; CHECK-PTX-NEXT:    .reg .b64 %SPL;
 ; CHECK-PTX-NEXT:    .reg .b32 %r<3>;
-; CHECK-PTX-NEXT:    .reg .b64 %rd<7>;
+; CHECK-PTX-NEXT:    .reg .b64 %rd<11>;
 ; CHECK-PTX-EMPTY:
 ; CHECK-PTX-NEXT:  // %bb.0: // %entry
 ; CHECK-PTX-NEXT:    mov.u64 %SPL, __local_depot7;
 ; CHECK-PTX-NEXT:    cvta.local.u64 %SP, %SPL;
-; CHECK-PTX-NEXT:    ld.global.nc.u64 %rd1, [__const_$_qux_$_s];
-; CHECK-PTX-NEXT:    st.u64 [%SP], %rd1;
-; CHECK-PTX-NEXT:    mov.u64 %rd2, __const_$_qux_$_s;
-; CHECK-PTX-NEXT:    add.s64 %rd3, %rd2, 8;
-; CHECK-PTX-NEXT:    ld.global.nc.u64 %rd4, [%rd3];
-; CHECK-PTX-NEXT:    st.u64 [%SP+8], %rd4;
-; CHECK-PTX-NEXT:    mov.b64 %rd5, 1;
-; CHECK-PTX-NEXT:    st.u64 [%SP+16], %rd5;
-; CHECK-PTX-NEXT:    add.u64 %rd6, %SP, 16;
+; CHECK-PTX-NEXT:    add.u64 %rd2, %SPL, 0;
+; CHECK-PTX-NEXT:    ld.global.nc.u64 %rd3, [__const_$_qux_$_s];
+; CHECK-PTX-NEXT:    st.local.u64 [%rd2], %rd3;
+; CHECK-PTX-NEXT:    mov.u64 %rd4, __const_$_qux_$_s;
+; CHECK-PTX-NEXT:    add.s64 %rd5, %rd4, 8;
+; CHECK-PTX-NEXT:    ld.global.nc.u64 %rd6, [%rd5];
+; CHECK-PTX-NEXT:    st.local.u64 [%rd2+8], %rd6;
+; CHECK-PTX-NEXT:    mov.b64 %rd7, 1;
+; CHECK-PTX-NEXT:    st.u64 [%SP+16], %rd7;
+; CHECK-PTX-NEXT:    ld.u64 %rd8, [%SP];
+; CHECK-PTX-NEXT:    ld.u64 %rd9, [%SP+8];
+; CHECK-PTX-NEXT:    add.u64 %rd10, %SP, 16;
 ; CHECK-PTX-NEXT:    { // callseq 3, 0
 ; CHECK-PTX-NEXT:    .param .align 8 .b8 param0[16];
-; CHECK-PTX-NEXT:    st.param.b64 [param0], %rd1;
-; CHECK-PTX-NEXT:    st.param.b64 [param0+8], %rd4;
+; CHECK-PTX-NEXT:    st.param.b64 [param0], %rd8;
+; CHECK-PTX-NEXT:    st.param.b64 [param0+8], %rd9;
 ; CHECK-PTX-NEXT:    .param .b64 param1;
-; CHECK-PTX-NEXT:    st.param.b64 [param1], %rd6;
+; CHECK-PTX-NEXT:    st.param.b64 [param1], %rd10;
 ; CHECK-PTX-NEXT:    .param .b32 retval0;
 ; CHECK-PTX-NEXT:    call.uni (retval0),
 ; CHECK-PTX-NEXT:    variadics4,
diff --git a/llvm/test/Transforms/InferAddressSpaces/NVPTX/alloca.ll b/llvm/test/Transforms/InferAddressSpaces/NVPTX/alloca.ll
new file mode 100644
index 000000000000000..fa063cdf8d80543
--- /dev/null
+++ b/llvm/test/Transforms/InferAddressSpaces/NVPTX/alloca.ll
@@ -0,0 +1,17 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=infer-address-spaces %s | FileCheck %s
+
+target triple = "nvptx64-nvidia-cuda"
+
+
+define float @load_alloca() {
+; CHECK-LABEL: define float @load_alloca() {
+; CHECK-NEXT:    [[ADDR:%.*]] = alloca float, align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = addrspacecast ptr [[ADDR]] to ptr addrspace(5)
+; CHECK-NEXT:    [[VAL:%.*]] = load float, ptr addrspace(5) [[TMP1]], align 4
+; CHECK-NEXT:    ret float [[VAL]]
+;
+  %addr = alloca float
+  %val = load float, ptr %addr
+  ret float %val
+}
diff --git a/llvm/test/tools/UpdateTestChecks/update_llc_test_checks/Inputs/nvptx-basic.ll.expected b/llvm/test/tools/UpdateTestChecks/update_llc_test_checks/Inputs/nvptx-basic.ll.expected
index a64364019de15e9..b0346f4db5ba194 100644
--- a/llvm/test/tools/UpdateTestChecks/update_llc_test_checks/Inputs/nvptx-basic.ll.expected
+++ b/llvm/test/tools/UpdateTestChecks/update_llc_test_checks/Inputs/nvptx-basic.ll.expected
@@ -9,38 +9,43 @@ define dso_local void @caller_St8x4(ptr nocapture noundef readonly byval(%struct
 ; CHECK-NEXT:    .local .align 8 .b8 __local_depot0[32];
 ; CHECK-NEXT:    .reg .b32 %SP;
 ; CHECK-NEXT:    .reg .b32 %SPL;
-; CHECK-NEXT:    .reg .b32 %r<2>;
-; CHECK-NEXT:    .reg .b64 %rd<13>;
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-NEXT:    .reg .b64 %rd<17>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    mov.u32 %SPL, __local_depot0;
 ; CHECK-NEXT:    cvta.local.u32 %SP, %SPL;
 ; CHECK-NEXT:    ld.param.u32 %r1, [caller_St8x4_param_1];
+; CHECK-NEXT:    add.u32 %r3, %SPL, 0;
 ; CHECK-NEXT:    ld.param.u64 %rd1, [caller_St8x4_param_0+24];
-; CHECK-NEXT:    st.u64 [%SP+24], %rd1;
+; CHECK-NEXT:    st.local.u64 [%r3+24], %rd1;
 ; CHECK-NEXT:    ld.param.u64 %rd2, [caller_St8x4_param_0+16];
-; CHECK-NEXT:    st.u64 [%SP+16], %rd2;
+; CHECK-NEXT:    st.local.u64 [%r3+16], %rd2;
 ; CHECK-NEXT:    ld.param.u64 %rd3, [caller_St8x4_param_0+8];
-; CHECK-NEXT:    st.u64 [%SP+8], %rd3;
+; CHECK-NEXT:    st.local.u64 [%r3+8], %rd3;
 ; CHECK-NEXT:    ld.param.u64 %rd4, [caller_St8x4_param_0];
-; CHECK-NEXT:    st.u64 [%SP], %rd4;
+; CHECK-NEXT:    st.local.u64 [%r3], %rd4;
+; CHECK-NEXT:    ld.u64 %rd5, [%SP+8];
+; CHECK-NEXT:    ld.u64 %rd6, [%SP];
+; CHECK-NEXT:    ld.u64 %rd7, [%SP+24];
+; CHECK-NEXT:    ld.u64 %rd8, [%SP+16];
 ; CHECK-NEXT:    { // callseq 0, 0
 ; CHECK-NEXT:    .param .align 16 .b8 param0[32];
-; CHECK-NEXT:    st.param.v2.b64 [param0], {%rd4, %rd3};
-; CHECK-NEXT:    st.param.v2.b64 [param0+16], {%rd2, %rd1};
+; CHECK-NEXT:    st.param.v2.b64 [param0], {%rd6, %rd5};
+; CHECK-NEXT:    st.param.v2.b64 [param0+16], {%rd8, %rd7};
 ; CHECK-NEXT:    .param .align 16 .b8 retval0[32];
 ; CHECK-NEXT:    call.uni (retval0),
 ; CHECK-NEXT:    callee_St8x4,
 ; CHECK-NEXT:    (
 ; CHECK-NEXT:    param0
 ; CHECK-NEXT:    );
-; CHECK-NEXT:    ld.param.v2.b64 {%rd5, %rd6}, [retval0];
-; CHECK-NEXT:    ld.param.v2.b64 {%rd7, %rd8}, [retval0+16];
+; CHECK-NEXT:    ld.param.v2.b64 {%rd9, %rd10}, [retval0];
+; CHECK-NEXT:    ld.param.v2.b64 {%rd11, %rd12}, [retval0+16];
 ; CHECK-NEXT:    } // callseq 0
-; CHECK-NEXT:    st.u64 [%r1], %rd5;
-; CHECK-NEXT:    st.u64 [%r1+8], %rd6;
-; CHECK-NEXT:    st.u64 [%r1+16], %rd7;
-; CHECK-NEXT:    st.u64 [%r1+24], %rd8;
+; CHECK-NEXT:    st.u64 [%r1], %rd9;
+; CHECK-NEXT:    st.u64 [%r1+8], %rd10;
+; CHECK-NEXT:    st.u64 [%r1+16], %rd11;
+; CHECK-NEXT:    st.u64 [%r1+24], %rd12;
 ; CHECK-NEXT:    ret;
   %call = tail call fastcc [4 x i64] @callee_St8x4(ptr noundef nonnull byval(%struct.St8x4) align 8 %in) #2
   %.fca.0.extract = extractvalue [4 x i64] %call, 0

>From 8a7217d6409a1520e792e05b61f16b3b4176811f Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Mon, 13 Jan 2025 23:29:25 +0000
Subject: [PATCH 2/4] [SDAG] Fixups required for InferAS change

---
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp   | 48 +++++++++++--------
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   | 22 +++++++--
 llvm/test/CodeGen/NVPTX/indirect_byval.ll     | 20 ++++----
 llvm/test/CodeGen/NVPTX/variadics-backend.ll  |  4 +-
 .../Inputs/nvptx-basic.ll.expected            | 23 ++++-----
 5 files changed, 67 insertions(+), 50 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index ac8ce05724750cb..da471d5c3c42602 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -25,6 +25,7 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Target/TargetIntrinsicInfo.h"
+#include <optional>
 
 using namespace llvm;
 
@@ -341,29 +342,34 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
   return true;
 }
 
-static unsigned int getCodeAddrSpace(MemSDNode *N) {
-  const Value *Src = N->getMemOperand()->getValue();
-
-  if (!Src)
+static std::optional<unsigned> convertAS(unsigned AS) { // IR AS -> NVPTX code
+  switch (AS) {
+  case llvm::ADDRESS_SPACE_LOCAL:
+    return NVPTX::AddressSpace::Local;
+  case llvm::ADDRESS_SPACE_GLOBAL:
+    return NVPTX::AddressSpace::Global;
+  case llvm::ADDRESS_SPACE_SHARED:
+    return NVPTX::AddressSpace::Shared;
+  case llvm::ADDRESS_SPACE_GENERIC:
     return NVPTX::AddressSpace::Generic;
-
-  if (auto *PT = dyn_cast<PointerType>(Src->getType())) {
-    switch (PT->getAddressSpace()) {
-    case llvm::ADDRESS_SPACE_LOCAL:
-      return NVPTX::AddressSpace::Local;
-    case llvm::ADDRESS_SPACE_GLOBAL:
-      return NVPTX::AddressSpace::Global;
-    case llvm::ADDRESS_SPACE_SHARED:
-      return NVPTX::AddressSpace::Shared;
-    case llvm::ADDRESS_SPACE_GENERIC:
-      return NVPTX::AddressSpace::Generic;
-    case llvm::ADDRESS_SPACE_PARAM:
-      return NVPTX::AddressSpace::Param;
-    case llvm::ADDRESS_SPACE_CONST:
-      return NVPTX::AddressSpace::Const;
-    default: break;
-    }
+  case llvm::ADDRESS_SPACE_PARAM:
+    return NVPTX::AddressSpace::Param;
+  case llvm::ADDRESS_SPACE_CONST:
+    return NVPTX::AddressSpace::Const;
+  default:
+    return std::nullopt; // AS is none of the cases above; caller decides
   }
+}
+
+static unsigned int getCodeAddrSpace(const MemSDNode *N) {
+  if (const Value *Src = N->getMemOperand()->getValue()) // 1st: IR pointer
+    if (auto *PT = dyn_cast<PointerType>(Src->getType()))
+      if (auto AS = convertAS(PT->getAddressSpace()))
+        return AS.value();
+
+  if (auto AS = convertAS(N->getMemOperand()->getAddrSpace())) // 2nd: MMO AS
+    return AS.value();
+
   return NVPTX::AddressSpace::Generic;
 }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 773c97f7b4dc0ff..18a8212a2918757 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1408,6 +1408,19 @@ static bool shouldConvertToIndirectCall(const CallBase *CB,
   return false;
 }
 
+static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
+                                      const DataLayout &DL,
+                                      const TargetLowering &TL) {
+  if (Ptr->getOpcode() == ISD::FrameIndex) { // only frame indices are refined
+    auto Ty = TL.getPointerTy(DL, ADDRESS_SPACE_LOCAL);
+    Ptr = DAG.getAddrSpaceCast(SDLoc(), Ty, Ptr, ADDRESS_SPACE_GENERIC,
+                               ADDRESS_SPACE_LOCAL); // overwrites caller's Ptr
+
+    return MachinePointerInfo(ADDRESS_SPACE_LOCAL);
+  }
+  return MachinePointerInfo(); // not a frame index: Ptr is left unchanged
+}
+
 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                                        SmallVectorImpl<SDValue> &InVals) const {
 
@@ -1572,11 +1585,12 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       }
 
       if (IsByVal) {
-        auto PtrVT = getPointerTy(DL);
-        SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
+        auto MPI = refinePtrAS(StVal, DAG, DL, *this);
+        const EVT PtrVT = StVal.getValueType();
+        SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
                                       DAG.getConstant(CurOffset, dl, PtrVT));
-        StVal = DAG.getLoad(EltVT, dl, TempChain, srcAddr, MachinePointerInfo(),
-                            PartAlign);
+
+        StVal = DAG.getLoad(EltVT, dl, TempChain, SrcAddr, MPI, PartAlign);
       } else if (ExtendIntegerParam) {
         assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
         // zext/sext to i32
diff --git a/llvm/test/CodeGen/NVPTX/indirect_byval.ll b/llvm/test/CodeGen/NVPTX/indirect_byval.ll
index d6c6e032f032fd5..3ae6300d8767d6e 100644
--- a/llvm/test/CodeGen/NVPTX/indirect_byval.ll
+++ b/llvm/test/CodeGen/NVPTX/indirect_byval.ll
@@ -17,19 +17,20 @@ define internal i32 @foo() {
 ; CHECK-NEXT:    .reg .b64 %SPL;
 ; CHECK-NEXT:    .reg .b16 %rs<2>;
 ; CHECK-NEXT:    .reg .b32 %r<3>;
-; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-NEXT:    .reg .b64 %rd<5>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0: // %entry
 ; CHECK-NEXT:    mov.u64 %SPL, __local_depot0;
 ; CHECK-NEXT:    cvta.local.u64 %SP, %SPL;
 ; CHECK-NEXT:    ld.global.u64 %rd1, [ptr];
-; CHECK-NEXT:    ld.u8 %rs1, [%SP+1];
-; CHECK-NEXT:    add.u64 %rd2, %SP, 0;
+; CHECK-NEXT:    add.u64 %rd3, %SPL, 1;
+; CHECK-NEXT:    ld.local.u8 %rs1, [%rd3];
+; CHECK-NEXT:    add.u64 %rd4, %SP, 0;
 ; CHECK-NEXT:    { // callseq 0, 0
 ; CHECK-NEXT:    .param .align 1 .b8 param0[1];
 ; CHECK-NEXT:    st.param.b8 [param0], %rs1;
 ; CHECK-NEXT:    .param .b64 param1;
-; CHECK-NEXT:    st.param.b64 [param1], %rd2;
+; CHECK-NEXT:    st.param.b64 [param1], %rd4;
 ; CHECK-NEXT:    .param .b32 retval0;
 ; CHECK-NEXT:    prototype_0 : .callprototype (.param .b32 _) _ (.param .align 1 .b8 _[1], .param .b64 _);
 ; CHECK-NEXT:    call (retval0),
@@ -59,19 +60,20 @@ define internal i32 @bar() {
 ; CHECK-NEXT:    .reg .b64 %SP;
 ; CHECK-NEXT:    .reg .b64 %SPL;
 ; CHECK-NEXT:    .reg .b32 %r<3>;
-; CHECK-NEXT:    .reg .b64 %rd<4>;
+; CHECK-NEXT:    .reg .b64 %rd<6>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0: // %entry
 ; CHECK-NEXT:    mov.u64 %SPL, __local_depot1;
 ; CHECK-NEXT:    cvta.local.u64 %SP, %SPL;
 ; CHECK-NEXT:    ld.global.u64 %rd1, [ptr];
-; CHECK-NEXT:    ld.u64 %rd2, [%SP+8];
-; CHECK-NEXT:    add.u64 %rd3, %SP, 0;
+; CHECK-NEXT:    add.u64 %rd3, %SPL, 8;
+; CHECK-NEXT:    ld.local.u64 %rd4, [%rd3];
+; CHECK-NEXT:    add.u64 %rd5, %SP, 0;
 ; CHECK-NEXT:    { // callseq 1, 0
 ; CHECK-NEXT:    .param .align 8 .b8 param0[8];
-; CHECK-NEXT:    st.param.b64 [param0], %rd2;
+; CHECK-NEXT:    st.param.b64 [param0], %rd4;
 ; CHECK-NEXT:    .param .b64 param1;
-; CHECK-NEXT:    st.param.b64 [param1], %rd3;
+; CHECK-NEXT:    st.param.b64 [param1], %rd5;
 ; CHECK-NEXT:    .param .b32 retval0;
 ; CHECK-NEXT:    prototype_1 : .callprototype (.param .b32 _) _ (.param .align 8 .b8 _[8], .param .b64 _);
 ; CHECK-NEXT:    call (retval0),
diff --git a/llvm/test/CodeGen/NVPTX/variadics-backend.ll b/llvm/test/CodeGen/NVPTX/variadics-backend.ll
index f5c1e238f553a54..c3296dd5298fc0f 100644
--- a/llvm/test/CodeGen/NVPTX/variadics-backend.ll
+++ b/llvm/test/CodeGen/NVPTX/variadics-backend.ll
@@ -397,8 +397,8 @@ define dso_local void @qux() {
 ; CHECK-PTX-NEXT:    st.local.u64 [%rd2+8], %rd6;
 ; CHECK-PTX-NEXT:    mov.b64 %rd7, 1;
 ; CHECK-PTX-NEXT:    st.u64 [%SP+16], %rd7;
-; CHECK-PTX-NEXT:    ld.u64 %rd8, [%SP];
-; CHECK-PTX-NEXT:    ld.u64 %rd9, [%SP+8];
+; CHECK-PTX-NEXT:    ld.local.u64 %rd8, [%rd2];
+; CHECK-PTX-NEXT:    ld.local.u64 %rd9, [%rd2+8];
 ; CHECK-PTX-NEXT:    add.u64 %rd10, %SP, 16;
 ; CHECK-PTX-NEXT:    { // callseq 3, 0
 ; CHECK-PTX-NEXT:    .param .align 8 .b8 param0[16];
diff --git a/llvm/test/tools/UpdateTestChecks/update_llc_test_checks/Inputs/nvptx-basic.ll.expected b/llvm/test/tools/UpdateTestChecks/update_llc_test_checks/Inputs/nvptx-basic.ll.expected
index b0346f4db5ba194..820ade631dd6405 100644
--- a/llvm/test/tools/UpdateTestChecks/update_llc_test_checks/Inputs/nvptx-basic.ll.expected
+++ b/llvm/test/tools/UpdateTestChecks/update_llc_test_checks/Inputs/nvptx-basic.ll.expected
@@ -10,11 +10,10 @@ define dso_local void @caller_St8x4(ptr nocapture noundef readonly byval(%struct
 ; CHECK-NEXT:    .reg .b32 %SP;
 ; CHECK-NEXT:    .reg .b32 %SPL;
 ; CHECK-NEXT:    .reg .b32 %r<4>;
-; CHECK-NEXT:    .reg .b64 %rd<17>;
+; CHECK-NEXT:    .reg .b64 %rd<13>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    mov.u32 %SPL, __local_depot0;
-; CHECK-NEXT:    cvta.local.u32 %SP, %SPL;
 ; CHECK-NEXT:    ld.param.u32 %r1, [caller_St8x4_param_1];
 ; CHECK-NEXT:    add.u32 %r3, %SPL, 0;
 ; CHECK-NEXT:    ld.param.u64 %rd1, [caller_St8x4_param_0+24];
@@ -25,27 +24,23 @@ define dso_local void @caller_St8x4(ptr nocapture noundef readonly byval(%struct
 ; CHECK-NEXT:    st.local.u64 [%r3+8], %rd3;
 ; CHECK-NEXT:    ld.param.u64 %rd4, [caller_St8x4_param_0];
 ; CHECK-NEXT:    st.local.u64 [%r3], %rd4;
-; CHECK-NEXT:    ld.u64 %rd5, [%SP+8];
-; CHECK-NEXT:    ld.u64 %rd6, [%SP];
-; CHECK-NEXT:    ld.u64 %rd7, [%SP+24];
-; CHECK-NEXT:    ld.u64 %rd8, [%SP+16];
 ; CHECK-NEXT:    { // callseq 0, 0
 ; CHECK-NEXT:    .param .align 16 .b8 param0[32];
-; CHECK-NEXT:    st.param.v2.b64 [param0], {%rd6, %rd5};
-; CHECK-NEXT:    st.param.v2.b64 [param0+16], {%rd8, %rd7};
+; CHECK-NEXT:    st.param.v2.b64 [param0], {%rd4, %rd3};
+; CHECK-NEXT:    st.param.v2.b64 [param0+16], {%rd2, %rd1};
 ; CHECK-NEXT:    .param .align 16 .b8 retval0[32];
 ; CHECK-NEXT:    call.uni (retval0),
 ; CHECK-NEXT:    callee_St8x4,
 ; CHECK-NEXT:    (
 ; CHECK-NEXT:    param0
 ; CHECK-NEXT:    );
-; CHECK-NEXT:    ld.param.v2.b64 {%rd9, %rd10}, [retval0];
-; CHECK-NEXT:    ld.param.v2.b64 {%rd11, %rd12}, [retval0+16];
+; CHECK-NEXT:    ld.param.v2.b64 {%rd5, %rd6}, [retval0];
+; CHECK-NEXT:    ld.param.v2.b64 {%rd7, %rd8}, [retval0+16];
 ; CHECK-NEXT:    } // callseq 0
-; CHECK-NEXT:    st.u64 [%r1], %rd9;
-; CHECK-NEXT:    st.u64 [%r1+8], %rd10;
-; CHECK-NEXT:    st.u64 [%r1+16], %rd11;
-; CHECK-NEXT:    st.u64 [%r1+24], %rd12;
+; CHECK-NEXT:    st.u64 [%r1], %rd5;
+; CHECK-NEXT:    st.u64 [%r1+8], %rd6;
+; CHECK-NEXT:    st.u64 [%r1+16], %rd7;
+; CHECK-NEXT:    st.u64 [%r1+24], %rd8;
 ; CHECK-NEXT:    ret;
   %call = tail call fastcc [4 x i64] @callee_St8x4(ptr noundef nonnull byval(%struct.St8x4) align 8 %in) #2
   %.fca.0.extract = extractvalue [4 x i64] %call, 0

>From 763cb2909bc22641fd2c71a512e038cf75eaed92 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Tue, 21 Jan 2025 21:03:16 +0000
Subject: [PATCH 3/4] address comments

---
 llvm/test/CodeGen/NVPTX/local-stack-frame.ll | 228 +++++++++++++++----
 1 file changed, 185 insertions(+), 43 deletions(-)

diff --git a/llvm/test/CodeGen/NVPTX/local-stack-frame.ll b/llvm/test/CodeGen/NVPTX/local-stack-frame.ll
index 7202e20628fe735..3523ffe6ae3cab0 100644
--- a/llvm/test/CodeGen/NVPTX/local-stack-frame.ll
+++ b/llvm/test/CodeGen/NVPTX/local-stack-frame.ll
@@ -1,3 +1,4 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
 ; RUN: llc < %s -mtriple=nvptx -mcpu=sm_20 -verify-machineinstrs | FileCheck %s --check-prefix=PTX32
 ; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 -verify-machineinstrs | FileCheck %s --check-prefix=PTX64
 ; RUN: %if ptxas && !ptxas-12.0 %{ llc < %s -mtriple=nvptx -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}
@@ -5,31 +6,91 @@
 
 ; Ensure we access the local stack properly
 
-; PTX32:        mov.u32          %SPL, __local_depot{{[0-9]+}};
-; PTX32:        ld.param.u32     %r{{[0-9]+}}, [foo_param_0];
-; PTX32:        add.u32          %r[[SP_REG:[0-9]+]], %SPL, 0;
-; PTX32:        st.local.u32  [%r[[SP_REG]]], %r{{[0-9]+}};
-; PTX64:        mov.u64          %SPL, __local_depot{{[0-9]+}};
-; PTX64:        ld.param.u32     %r{{[0-9]+}}, [foo_param_0];
-; PTX64:        add.u64          %rd[[SP_REG:[0-9]+]], %SPL, 0;
-; PTX64:        st.local.u32  [%rd[[SP_REG]]], %r{{[0-9]+}};
 define void @foo(i32 %a) {
+; PTX32-LABEL: foo(
+; PTX32:       {
+; PTX32-NEXT:    .local .align 4 .b8 __local_depot0[4];
+; PTX32-NEXT:    .reg .b32 %SP;
+; PTX32-NEXT:    .reg .b32 %SPL;
+; PTX32-NEXT:    .reg .b32 %r<4>;
+; PTX32-EMPTY:
+; PTX32-NEXT:  // %bb.0:
+; PTX32-NEXT:    mov.u32 %SPL, __local_depot0;
+; PTX32-NEXT:    ld.param.u32 %r1, [foo_param_0];
+; PTX32-NEXT:    add.u32 %r3, %SPL, 0;
+; PTX32-NEXT:    st.local.u32 [%r3], %r1;
+; PTX32-NEXT:    ret;
+;
+; PTX64-LABEL: foo(
+; PTX64:       {
+; PTX64-NEXT:    .local .align 4 .b8 __local_depot0[4];
+; PTX64-NEXT:    .reg .b64 %SP;
+; PTX64-NEXT:    .reg .b64 %SPL;
+; PTX64-NEXT:    .reg .b32 %r<2>;
+; PTX64-NEXT:    .reg .b64 %rd<3>;
+; PTX64-EMPTY:
+; PTX64-NEXT:  // %bb.0:
+; PTX64-NEXT:    mov.u64 %SPL, __local_depot0;
+; PTX64-NEXT:    ld.param.u32 %r1, [foo_param_0];
+; PTX64-NEXT:    add.u64 %rd2, %SPL, 0;
+; PTX64-NEXT:    st.local.u32 [%rd2], %r1;
+; PTX64-NEXT:    ret;
   %local = alloca i32, align 4
   store volatile i32 %a, ptr %local
   ret void
 }
 
-; PTX32:        mov.u32          %SPL, __local_depot{{[0-9]+}};
-; PTX32:        cvta.local.u32   %SP, %SPL;
-; PTX32:        ld.param.u32     %r{{[0-9]+}}, [foo2_param_0];
-; PTX32:        add.u32          %r[[SP_REG:[0-9]+]], %SPL, 0;
-; PTX32:        st.local.u32  [%r[[SP_REG]]], %r{{[0-9]+}};
-; PTX64:        mov.u64          %SPL, __local_depot{{[0-9]+}};
-; PTX64:        cvta.local.u64   %SP, %SPL;
-; PTX64:        ld.param.u32     %r{{[0-9]+}}, [foo2_param_0];
-; PTX64:        add.u64          %rd[[SP_REG:[0-9]+]], %SPL, 0;
-; PTX64:        st.local.u32  [%rd[[SP_REG]]], %r{{[0-9]+}};
 define ptx_kernel void @foo2(i32 %a) {
+; PTX32-LABEL: foo2(
+; PTX32:       {
+; PTX32-NEXT:    .local .align 4 .b8 __local_depot1[4];
+; PTX32-NEXT:    .reg .b32 %SP;
+; PTX32-NEXT:    .reg .b32 %SPL;
+; PTX32-NEXT:    .reg .b32 %r<4>;
+; PTX32-EMPTY:
+; PTX32-NEXT:  // %bb.0:
+; PTX32-NEXT:    mov.u32 %SPL, __local_depot1;
+; PTX32-NEXT:    cvta.local.u32 %SP, %SPL;
+; PTX32-NEXT:    ld.param.u32 %r1, [foo2_param_0];
+; PTX32-NEXT:    add.u32 %r2, %SP, 0;
+; PTX32-NEXT:    add.u32 %r3, %SPL, 0;
+; PTX32-NEXT:    st.local.u32 [%r3], %r1;
+; PTX32-NEXT:    { // callseq 0, 0
+; PTX32-NEXT:    .param .b32 param0;
+; PTX32-NEXT:    st.param.b32 [param0], %r2;
+; PTX32-NEXT:    call.uni
+; PTX32-NEXT:    bar,
+; PTX32-NEXT:    (
+; PTX32-NEXT:    param0
+; PTX32-NEXT:    );
+; PTX32-NEXT:    } // callseq 0
+; PTX32-NEXT:    ret;
+;
+; PTX64-LABEL: foo2(
+; PTX64:       {
+; PTX64-NEXT:    .local .align 4 .b8 __local_depot1[4];
+; PTX64-NEXT:    .reg .b64 %SP;
+; PTX64-NEXT:    .reg .b64 %SPL;
+; PTX64-NEXT:    .reg .b32 %r<2>;
+; PTX64-NEXT:    .reg .b64 %rd<3>;
+; PTX64-EMPTY:
+; PTX64-NEXT:  // %bb.0:
+; PTX64-NEXT:    mov.u64 %SPL, __local_depot1;
+; PTX64-NEXT:    cvta.local.u64 %SP, %SPL;
+; PTX64-NEXT:    ld.param.u32 %r1, [foo2_param_0];
+; PTX64-NEXT:    add.u64 %rd1, %SP, 0;
+; PTX64-NEXT:    add.u64 %rd2, %SPL, 0;
+; PTX64-NEXT:    st.local.u32 [%rd2], %r1;
+; PTX64-NEXT:    { // callseq 0, 0
+; PTX64-NEXT:    .param .b64 param0;
+; PTX64-NEXT:    st.param.b64 [param0], %rd1;
+; PTX64-NEXT:    call.uni
+; PTX64-NEXT:    bar,
+; PTX64-NEXT:    (
+; PTX64-NEXT:    param0
+; PTX64-NEXT:    );
+; PTX64-NEXT:    } // callseq 0
+; PTX64-NEXT:    ret;
   %local = alloca i32, align 4
   store i32 %a, ptr %local
   call void @bar(ptr %local)
@@ -38,39 +99,120 @@ define ptx_kernel void @foo2(i32 %a) {
 
 declare void @bar(ptr %a)
 
-
-; PTX32:        mov.u32          %SPL, __local_depot{{[0-9]+}};
-; PTX32-NOT:    cvta.local.u32   %SP, %SPL;
-; PTX32:        ld.param.u32     %r{{[0-9]+}}, [foo3_param_0];
-; PTX32:        add.u32          %r{{[0-9]+}}, %SPL, 0;
-; PTX32:        st.local.u32  [%r{{[0-9]+}}], %r{{[0-9]+}};
-; PTX64:        mov.u64          %SPL, __local_depot{{[0-9]+}};
-; PTX64-NOT:    cvta.local.u64   %SP, %SPL;
-; PTX64:        ld.param.u32     %r{{[0-9]+}}, [foo3_param_0];
-; PTX64:        add.u64          %rd{{[0-9]+}}, %SPL, 0;
-; PTX64:        st.local.u32  [%rd{{[0-9]+}}], %r{{[0-9]+}};
 define void @foo3(i32 %a) {
+; PTX32-LABEL: foo3(
+; PTX32:       {
+; PTX32-NEXT:    .local .align 4 .b8 __local_depot2[12];
+; PTX32-NEXT:    .reg .b32 %SP;
+; PTX32-NEXT:    .reg .b32 %SPL;
+; PTX32-NEXT:    .reg .b32 %r<6>;
+; PTX32-EMPTY:
+; PTX32-NEXT:  // %bb.0:
+; PTX32-NEXT:    mov.u32 %SPL, __local_depot2;
+; PTX32-NEXT:    ld.param.u32 %r1, [foo3_param_0];
+; PTX32-NEXT:    add.u32 %r3, %SPL, 0;
+; PTX32-NEXT:    shl.b32 %r4, %r1, 2;
+; PTX32-NEXT:    add.s32 %r5, %r3, %r4;
+; PTX32-NEXT:    st.local.u32 [%r5], %r1;
+; PTX32-NEXT:    ret;
+;
+; PTX64-LABEL: foo3(
+; PTX64:       {
+; PTX64-NEXT:    .local .align 4 .b8 __local_depot2[12];
+; PTX64-NEXT:    .reg .b64 %SP;
+; PTX64-NEXT:    .reg .b64 %SPL;
+; PTX64-NEXT:    .reg .b32 %r<2>;
+; PTX64-NEXT:    .reg .b64 %rd<5>;
+; PTX64-EMPTY:
+; PTX64-NEXT:  // %bb.0:
+; PTX64-NEXT:    mov.u64 %SPL, __local_depot2;
+; PTX64-NEXT:    ld.param.u32 %r1, [foo3_param_0];
+; PTX64-NEXT:    add.u64 %rd2, %SPL, 0;
+; PTX64-NEXT:    mul.wide.s32 %rd3, %r1, 4;
+; PTX64-NEXT:    add.s64 %rd4, %rd2, %rd3;
+; PTX64-NEXT:    st.local.u32 [%rd4], %r1;
+; PTX64-NEXT:    ret;
   %local = alloca [3 x i32], align 4
   %1 = getelementptr inbounds i32, ptr %local, i32 %a
   store i32 %a, ptr %1
   ret void
 }
 
-; PTX32:        cvta.local.u32   %SP, %SPL;
-; PTX32:        add.u32          {{%r[0-9]+}}, %SP, 0;
-; PTX32:        add.u32          {{%r[0-9]+}}, %SPL, 0;
-; PTX32:        add.u32          {{%r[0-9]+}}, %SP, 4;
-; PTX32:        add.u32          {{%r[0-9]+}}, %SPL, 4;
-; PTX32:        st.local.u32     [{{%r[0-9]+}}], {{%r[0-9]+}}
-; PTX32:        st.local.u32     [{{%r[0-9]+}}], {{%r[0-9]+}}
-; PTX64:        cvta.local.u64   %SP, %SPL;
-; PTX64:        add.u64          {{%rd[0-9]+}}, %SP, 0;
-; PTX64:        add.u64          {{%rd[0-9]+}}, %SPL, 0;
-; PTX64:        add.u64          {{%rd[0-9]+}}, %SP, 4;
-; PTX64:        add.u64          {{%rd[0-9]+}}, %SPL, 4;
-; PTX64:        st.local.u32     [{{%rd[0-9]+}}], {{%r[0-9]+}}
-; PTX64:        st.local.u32     [{{%rd[0-9]+}}], {{%r[0-9]+}}
 define void @foo4() {
+; PTX32-LABEL: foo4(
+; PTX32:       {
+; PTX32-NEXT:    .local .align 4 .b8 __local_depot3[8];
+; PTX32-NEXT:    .reg .b32 %SP;
+; PTX32-NEXT:    .reg .b32 %SPL;
+; PTX32-NEXT:    .reg .b32 %r<6>;
+; PTX32-EMPTY:
+; PTX32-NEXT:  // %bb.0:
+; PTX32-NEXT:    mov.u32 %SPL, __local_depot3;
+; PTX32-NEXT:    cvta.local.u32 %SP, %SPL;
+; PTX32-NEXT:    add.u32 %r1, %SP, 0;
+; PTX32-NEXT:    add.u32 %r2, %SPL, 0;
+; PTX32-NEXT:    add.u32 %r3, %SP, 4;
+; PTX32-NEXT:    add.u32 %r4, %SPL, 4;
+; PTX32-NEXT:    mov.b32 %r5, 0;
+; PTX32-NEXT:    st.local.u32 [%r2], %r5;
+; PTX32-NEXT:    st.local.u32 [%r4], %r5;
+; PTX32-NEXT:    { // callseq 1, 0
+; PTX32-NEXT:    .param .b32 param0;
+; PTX32-NEXT:    st.param.b32 [param0], %r1;
+; PTX32-NEXT:    call.uni
+; PTX32-NEXT:    bar,
+; PTX32-NEXT:    (
+; PTX32-NEXT:    param0
+; PTX32-NEXT:    );
+; PTX32-NEXT:    } // callseq 1
+; PTX32-NEXT:    { // callseq 2, 0
+; PTX32-NEXT:    .param .b32 param0;
+; PTX32-NEXT:    st.param.b32 [param0], %r3;
+; PTX32-NEXT:    call.uni
+; PTX32-NEXT:    bar,
+; PTX32-NEXT:    (
+; PTX32-NEXT:    param0
+; PTX32-NEXT:    );
+; PTX32-NEXT:    } // callseq 2
+; PTX32-NEXT:    ret;
+;
+; PTX64-LABEL: foo4(
+; PTX64:       {
+; PTX64-NEXT:    .local .align 4 .b8 __local_depot3[8];
+; PTX64-NEXT:    .reg .b64 %SP;
+; PTX64-NEXT:    .reg .b64 %SPL;
+; PTX64-NEXT:    .reg .b32 %r<2>;
+; PTX64-NEXT:    .reg .b64 %rd<5>;
+; PTX64-EMPTY:
+; PTX64-NEXT:  // %bb.0:
+; PTX64-NEXT:    mov.u64 %SPL, __local_depot3;
+; PTX64-NEXT:    cvta.local.u64 %SP, %SPL;
+; PTX64-NEXT:    add.u64 %rd1, %SP, 0;
+; PTX64-NEXT:    add.u64 %rd2, %SPL, 0;
+; PTX64-NEXT:    add.u64 %rd3, %SP, 4;
+; PTX64-NEXT:    add.u64 %rd4, %SPL, 4;
+; PTX64-NEXT:    mov.b32 %r1, 0;
+; PTX64-NEXT:    st.local.u32 [%rd2], %r1;
+; PTX64-NEXT:    st.local.u32 [%rd4], %r1;
+; PTX64-NEXT:    { // callseq 1, 0
+; PTX64-NEXT:    .param .b64 param0;
+; PTX64-NEXT:    st.param.b64 [param0], %rd1;
+; PTX64-NEXT:    call.uni
+; PTX64-NEXT:    bar,
+; PTX64-NEXT:    (
+; PTX64-NEXT:    param0
+; PTX64-NEXT:    );
+; PTX64-NEXT:    } // callseq 1
+; PTX64-NEXT:    { // callseq 2, 0
+; PTX64-NEXT:    .param .b64 param0;
+; PTX64-NEXT:    st.param.b64 [param0], %rd3;
+; PTX64-NEXT:    call.uni
+; PTX64-NEXT:    bar,
+; PTX64-NEXT:    (
+; PTX64-NEXT:    param0
+; PTX64-NEXT:    );
+; PTX64-NEXT:    } // callseq 2
+; PTX64-NEXT:    ret;
   %A = alloca i32
   %B = alloca i32
   store i32 0, ptr %A

>From 563d80d7545078fe4a3e08698291a19852be3a35 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Tue, 28 Jan 2025 00:45:07 +0000
Subject: [PATCH 4/4] address comments

---
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 11 ++---------
 1 file changed, 2 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index da471d5c3c42602..6c9f11fa1bcf8ba 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -362,15 +362,8 @@ static std::optional<unsigned> convertAS(unsigned AS) {
 }
 
 static unsigned int getCodeAddrSpace(const MemSDNode *N) {
-  if (const Value *Src = N->getMemOperand()->getValue())
-    if (auto *PT = dyn_cast<PointerType>(Src->getType()))
-      if (auto AS = convertAS(PT->getAddressSpace()))
-        return AS.value();
-
-  if (auto AS = convertAS(N->getMemOperand()->getAddrSpace()))
-    return AS.value();
-
-  return NVPTX::AddressSpace::Generic;
+  return convertAS(N->getMemOperand()->getAddrSpace())
+      .value_or(NVPTX::AddressSpace::Generic);
 }
 
 namespace {



More information about the llvm-commits mailing list