[llvm] [NVPTX] Annotate CUDA kernel pointer arguments with .ptr .space .align attributes. (PR #79646)

via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 1 19:45:10 PST 2024


https://github.com/Vandana2896 updated https://github.com/llvm/llvm-project/pull/79646

>From d5bd0215f22440e10e1a4af2b4391973831795be Mon Sep 17 00:00:00 2001
From: Vandana <vandanak at nvidia.com>
Date: Fri, 26 Jan 2024 13:03:27 -0800
Subject: [PATCH 1/2] Enable .ptr .global .align attributes for CUDA kernel
 pointer arguments

---
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp     |  4 +++
 llvm/test/CodeGen/NVPTX/kernel-param-align.ll | 34 +++++++++++++++++++
 2 files changed, 38 insertions(+)
 create mode 100644 llvm/test/CodeGen/NVPTX/kernel-param-align.ll

diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 6c4879ba183c0..0a0fbff2ad6c1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1610,6 +1610,10 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
             }
             Align ParamAlign = I->getParamAlign().valueOrOne();
             O << ".align " << ParamAlign.value() << " ";
+          } else if (I->getParamAlign().valueOrOne() != 1) {
+            O << ".ptr .global ";
+            Align ParamAlign = I->getParamAlign().value();
+            O << ".align " << ParamAlign.value() << " ";
           }
           O << TLI->getParamName(F, paramIndex);
           continue;
diff --git a/llvm/test/CodeGen/NVPTX/kernel-param-align.ll b/llvm/test/CodeGen/NVPTX/kernel-param-align.ll
new file mode 100644
index 0000000000000..eda45928ea305
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/kernel-param-align.ll
@@ -0,0 +1,34 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_72 2>&1 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_72 | %ptxas-verify %}
+
+%struct.Large = type { [16 x double] } ; 128 bytes: 16 x 8-byte double
+
+; CHECK: .param .u64 .ptr .global .align 16 func_align_param_0,
+; CHECK: .param .u64 .ptr .global .align 1 func_align_param_1,
+; CHECK: .param .u32 func_align_param_2
+define void @func_align(ptr nocapture readonly align 16 %input, ptr nocapture %out, i32 %n) {
+entry: ; %input has an explicit align 16; %out has none, so the printer falls back to .align 1 (valueOrOne)
+  %0 = addrspacecast ptr %out to ptr addrspace(1)
+  %1 = addrspacecast ptr %input to ptr addrspace(1)
+  %getElem = getelementptr inbounds %struct.Large, ptr addrspace(1) %1, i64 0, i32 0, i64 5 ; element 5 of [16 x double] => byte offset 40
+  %tmp2 = load i32, ptr addrspace(1) %getElem, align 8
+  store i32 %tmp2, ptr addrspace(1) %0, align 4
+  ret void
+}
+
+; CHECK: .param .u64 .ptr .global .align 1 func_param_0,
+; CHECK: .param .u64 .ptr .global .align 1 func_param_1,
+; CHECK: .param .u32 func_param_2
+define void @func(ptr nocapture readonly %input, ptr nocapture %out, i32 %n) {
+entry: ; neither pointer has an align attribute, so both are expected with the default .align 1
+  %0 = addrspacecast ptr %out to ptr addrspace(1)
+  %1 = addrspacecast ptr %input to ptr addrspace(1)
+  %getElem = getelementptr inbounds %struct.Large, ptr addrspace(1) %1, i64 0, i32 0, i64 5 ; element 5 of [16 x double] => byte offset 40
+  %tmp2 = load i32, ptr addrspace(1) %getElem, align 8
+  store i32 %tmp2, ptr addrspace(1) %0, align 4
+  ret void
+}
+
+!nvvm.annotations = !{!0, !1} ; presumably what makes both functions take the isKernelFunc path -- confirm
+!0 = !{ptr @func_align, !"kernel", i32 1}
+!1 = !{ptr @func, !"kernel", i32 1}

>From a952cc1c655125829e49d11b46d7979b9052a671 Mon Sep 17 00:00:00 2001
From: Vandana <vandanak at nvidia.com>
Date: Thu, 1 Feb 2024 19:44:36 -0800
Subject: [PATCH 2/2] Always emit .ptr .global .align for CUDA kernel pointer
 params; add comment

---
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 0a0fbff2ad6c1..4d2da1f56050c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1589,11 +1589,16 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
       if (isKernelFunc) {
         if (PTy) {
           // Special handling for pointer arguments to kernel
+          // CUDA kernels assume that pointers are in global address space
+          // See:
+          // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
           O << "\t.param .u" << PTySizeInBits << " ";
 
-          if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() !=
-              NVPTX::CUDA) {
-            int addrSpace = PTy->getAddressSpace();
+          int addrSpace = PTy->getAddressSpace();
+          if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
+            assert(addrSpace == 0 && "Invalid address space");
+            O << ".ptr .global ";
+          } else {
             switch (addrSpace) {
             default:
               O << ".ptr ";
@@ -1608,13 +1613,9 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
               O << ".ptr .global ";
               break;
             }
-            Align ParamAlign = I->getParamAlign().valueOrOne();
-            O << ".align " << ParamAlign.value() << " ";
-          } else if (I->getParamAlign().valueOrOne() != 1) {
-            O << ".ptr .global ";
-            Align ParamAlign = I->getParamAlign().value();
-            O << ".align " << ParamAlign.value() << " ";
           }
+          Align ParamAlign = I->getParamAlign().valueOrOne();
+          O << ".align " << ParamAlign.value() << " ";
           O << TLI->getParamName(F, paramIndex);
           continue;
         }



More information about the llvm-commits mailing list