[llvm] [NVPTX] restrict `cvta.param` use to kernels only. (PR #112278)

via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 14 15:12:08 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Artem Belevich (Artem-B)

<details>
<summary>Changes</summary>

When used in a regular (non-kernel) function, `cvta.param` may produce an invalid pointer, so its use is restricted to kernels.

---
Full diff: https://github.com/llvm/llvm-project/pull/112278.diff


2 Files Affected:

- (modified) llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp (+2-1) 
- (modified) llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll (+75-7) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index 4a184037add4ce..3041c16c7a7604 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -545,7 +545,8 @@ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
 void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
                                       Argument *Arg) {
   Function *Func = Arg->getParent();
-  bool HasCvtaParam = TM.getSubtargetImpl(*Func)->hasCvtaParam();
+  bool HasCvtaParam =
+      TM.getSubtargetImpl(*Func)->hasCvtaParam() && isKernelFunction(*Func);
   bool IsGridConstant = HasCvtaParam && isParamGridConstant(*Arg);
   const DataLayout &DL = Func->getDataLayout();
   BasicBlock::iterator FirstInst = Func->getEntryBlock().begin();
diff --git a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
index b203a78d677308..33fa3afc94b89d 100644
--- a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
@@ -2,6 +2,72 @@
 ; RUN: opt < %s -S -nvptx-lower-args --mtriple nvptx64-nvidia-cuda -mcpu=sm_70 -mattr=+ptx77 | FileCheck %s --check-prefixes OPT
 ; RUN: llc < %s --mtriple nvptx64-nvidia-cuda -mcpu=sm_70 -mattr=+ptx77 | FileCheck %s --check-prefixes PTX
 
+%struct.uint4 = type { i32, i32, i32, i32 }
+
+@gi = dso_local addrspace(1) externally_initialized global %struct.uint4 { i32 50462976, i32 117835012, i32 185207048, i32 252579084 }, align 16
+
+; Function Attrs: mustprogress nofree noinline norecurse nosync nounwind willreturn memory(read, inaccessiblemem: none)
+; Regular functions must still make a copy. `cvta.param` does not always work there.
+define dso_local noundef i32 @non_kernel_function(ptr nocapture noundef readonly byval(%struct.uint4) align 16 %a, i1 noundef zeroext %b, i32 noundef %c) local_unnamed_addr #0 {
+; OPT-LABEL: define dso_local noundef i32 @non_kernel_function(
+; OPT-SAME: ptr nocapture noundef readonly byval([[STRUCT_UINT4:%.*]]) align 16 [[A:%.*]], i1 noundef zeroext [[B:%.*]], i32 noundef [[C:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+; OPT-NEXT:  [[ENTRY:.*:]]
+; OPT-NEXT:    [[A1:%.*]] = alloca [[STRUCT_UINT4]], align 16
+; OPT-NEXT:    [[A2:%.*]] = addrspacecast ptr [[A]] to ptr addrspace(101)
+; OPT-NEXT:    call void @llvm.memcpy.p0.p101.i64(ptr align 16 [[A1]], ptr addrspace(101) align 16 [[A2]], i64 16, i1 false)
+; OPT-NEXT:    [[A_:%.*]] = select i1 [[B]], ptr [[A1]], ptr addrspacecast (ptr addrspace(1) @gi to ptr)
+; OPT-NEXT:    [[IDX_EXT:%.*]] = sext i32 [[C]] to i64
+; OPT-NEXT:    [[ADD_PTR:%.*]] = getelementptr inbounds i8, ptr [[A_]], i64 [[IDX_EXT]]
+; OPT-NEXT:    [[TMP0:%.*]] = load i32, ptr [[ADD_PTR]], align 1
+; OPT-NEXT:    ret i32 [[TMP0]]
+;
+; PTX-LABEL: non_kernel_function(
+; PTX:       {
+; PTX-NEXT:    .local .align 16 .b8 __local_depot0[16];
+; PTX-NEXT:    .reg .b64 %SP;
+; PTX-NEXT:    .reg .b64 %SPL;
+; PTX-NEXT:    .reg .pred %p<2>;
+; PTX-NEXT:    .reg .b16 %rs<3>;
+; PTX-NEXT:    .reg .b32 %r<11>;
+; PTX-NEXT:    .reg .b64 %rd<10>;
+; PTX-EMPTY:
+; PTX-NEXT:  // %bb.0: // %entry
+; PTX-NEXT:    mov.u64 %SPL, __local_depot0;
+; PTX-NEXT:    cvta.local.u64 %SP, %SPL;
+; PTX-NEXT:    ld.param.u8 %rs1, [non_kernel_function_param_1];
+; PTX-NEXT:    and.b16 %rs2, %rs1, 1;
+; PTX-NEXT:    setp.eq.b16 %p1, %rs2, 1;
+; PTX-NEXT:    ld.param.s32 %rd1, [non_kernel_function_param_2];
+; PTX-NEXT:    add.u64 %rd2, %SP, 0;
+; PTX-NEXT:    or.b64 %rd3, %rd2, 8;
+; PTX-NEXT:    ld.param.u64 %rd4, [non_kernel_function_param_0+8];
+; PTX-NEXT:    st.u64 [%rd3], %rd4;
+; PTX-NEXT:    ld.param.u64 %rd5, [non_kernel_function_param_0];
+; PTX-NEXT:    st.u64 [%SP+0], %rd5;
+; PTX-NEXT:    mov.u64 %rd6, gi;
+; PTX-NEXT:    cvta.global.u64 %rd7, %rd6;
+; PTX-NEXT:    selp.b64 %rd8, %rd2, %rd7, %p1;
+; PTX-NEXT:    add.s64 %rd9, %rd8, %rd1;
+; PTX-NEXT:    ld.u8 %r1, [%rd9];
+; PTX-NEXT:    ld.u8 %r2, [%rd9+1];
+; PTX-NEXT:    shl.b32 %r3, %r2, 8;
+; PTX-NEXT:    or.b32 %r4, %r3, %r1;
+; PTX-NEXT:    ld.u8 %r5, [%rd9+2];
+; PTX-NEXT:    shl.b32 %r6, %r5, 16;
+; PTX-NEXT:    ld.u8 %r7, [%rd9+3];
+; PTX-NEXT:    shl.b32 %r8, %r7, 24;
+; PTX-NEXT:    or.b32 %r9, %r8, %r6;
+; PTX-NEXT:    or.b32 %r10, %r9, %r4;
+; PTX-NEXT:    st.param.b32 [func_retval0+0], %r10;
+; PTX-NEXT:    ret;
+entry:
+  %a. = select i1 %b, ptr %a, ptr addrspacecast (ptr addrspace(1) @gi to ptr), !dbg !17
+  %idx.ext = sext i32 %c to i64, !dbg !18
+  %add.ptr = getelementptr inbounds i8, ptr %a., i64 %idx.ext, !dbg !18
+  %0 = load i32, ptr %add.ptr, align 1, !dbg !19
+  ret i32 %0, !dbg !23
+}
+
 define void @grid_const_int(ptr byval(i32) align 4 %input1, i32 %input2, ptr %out, i32 %n) {
 ; PTX-LABEL: grid_const_int(
 ; PTX:       {
@@ -17,7 +83,7 @@ define void @grid_const_int(ptr byval(i32) align 4 %input1, i32 %input2, ptr %ou
 ; PTX-NEXT:    st.global.u32 [%rd2], %r3;
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define void @grid_const_int(
-; OPT-SAME: ptr byval(i32) align 4 [[INPUT1:%.*]], i32 [[INPUT2:%.*]], ptr [[OUT:%.*]], i32 [[N:%.*]]) #[[ATTR0:[0-9]+]] {
+; OPT-SAME: ptr byval(i32) align 4 [[INPUT1:%.*]], i32 [[INPUT2:%.*]], ptr [[OUT:%.*]], i32 [[N:%.*]]) #[[ATTR0]] {
 ; OPT-NEXT:    [[OUT2:%.*]] = addrspacecast ptr [[OUT]] to ptr addrspace(1)
 ; OPT-NEXT:    [[OUT3:%.*]] = addrspacecast ptr addrspace(1) [[OUT2]] to ptr
 ; OPT-NEXT:    [[INPUT11:%.*]] = addrspacecast ptr [[INPUT1]] to ptr addrspace(101)
@@ -106,14 +172,14 @@ define void @grid_const_escape(ptr byval(%struct.s) align 4 %input) {
 define void @multiple_grid_const_escape(ptr byval(%struct.s) align 4 %input, i32 %a, ptr byval(i32) align 4 %b) {
 ; PTX-LABEL: multiple_grid_const_escape(
 ; PTX:       {
-; PTX-NEXT:    .local .align 4 .b8 __local_depot3[4];
+; PTX-NEXT:    .local .align 4 .b8 __local_depot4[4];
 ; PTX-NEXT:    .reg .b64 %SP;
 ; PTX-NEXT:    .reg .b64 %SPL;
 ; PTX-NEXT:    .reg .b32 %r<4>;
 ; PTX-NEXT:    .reg .b64 %rd<10>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0:
-; PTX-NEXT:    mov.u64 %SPL, __local_depot3;
+; PTX-NEXT:    mov.u64 %SPL, __local_depot4;
 ; PTX-NEXT:    cvta.local.u64 %SP, %SPL;
 ; PTX-NEXT:    mov.b64 %rd2, multiple_grid_const_escape_param_0;
 ; PTX-NEXT:    mov.b64 %rd3, multiple_grid_const_escape_param_2;
@@ -342,10 +408,10 @@ define void @grid_const_phi(ptr byval(%struct.s) align 4 %input1, ptr %inout) {
 ; PTX-NEXT:    cvta.param.u64 %rd8, %rd7;
 ; PTX-NEXT:    ld.global.u32 %r1, [%rd1];
 ; PTX-NEXT:    setp.lt.s32 %p1, %r1, 0;
-; PTX-NEXT:    @%p1 bra $L__BB8_2;
+; PTX-NEXT:    @%p1 bra $L__BB9_2;
 ; PTX-NEXT:  // %bb.1: // %second
 ; PTX-NEXT:    add.s64 %rd8, %rd8, 4;
-; PTX-NEXT:  $L__BB8_2: // %merge
+; PTX-NEXT:  $L__BB9_2: // %merge
 ; PTX-NEXT:    ld.u32 %r2, [%rd8];
 ; PTX-NEXT:    st.global.u32 [%rd1], %r2;
 ; PTX-NEXT:    ret;
@@ -402,13 +468,13 @@ define void @grid_const_phi_ngc(ptr byval(%struct.s) align 4 %input1, ptr byval(
 ; PTX-NEXT:    cvta.param.u64 %rd11, %rd10;
 ; PTX-NEXT:    ld.global.u32 %r1, [%rd1];
 ; PTX-NEXT:    setp.lt.s32 %p1, %r1, 0;
-; PTX-NEXT:    @%p1 bra $L__BB9_2;
+; PTX-NEXT:    @%p1 bra $L__BB10_2;
 ; PTX-NEXT:  // %bb.1: // %second
 ; PTX-NEXT:    mov.b64 %rd8, grid_const_phi_ngc_param_1;
 ; PTX-NEXT:    mov.u64 %rd9, %rd8;
 ; PTX-NEXT:    cvta.param.u64 %rd2, %rd9;
 ; PTX-NEXT:    add.s64 %rd11, %rd2, 4;
-; PTX-NEXT:  $L__BB9_2: // %merge
+; PTX-NEXT:  $L__BB10_2: // %merge
 ; PTX-NEXT:    ld.u32 %r2, [%rd11];
 ; PTX-NEXT:    st.global.u32 [%rd1], %r2;
 ; PTX-NEXT:    ret;
@@ -567,3 +633,5 @@ declare dso_local ptr @escape3(ptr, ptr, ptr) local_unnamed_addr
 
 !22 = !{ptr @grid_const_ptrtoint, !"kernel", i32 1, !"grid_constant", !23}
 !23 = !{i32 1}
+
+

``````````

</details>


https://github.com/llvm/llvm-project/pull/112278


More information about the llvm-commits mailing list