[llvm] [NVPTX] Annotate CUDA kernel pointer arguments with .ptr .space .align attributes. (PR #79646)

via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 25 07:31:09 PDT 2024


https://github.com/Vandana2896 updated https://github.com/llvm/llvm-project/pull/79646

>From d5bd0215f22440e10e1a4af2b4391973831795be Mon Sep 17 00:00:00 2001
From: Vandana <vandanak at nvidia.com>
Date: Fri, 26 Jan 2024 13:03:27 -0800
Subject: [PATCH 1/8] Enable .ptr .global .align attributes for kernel
 arguments for CUDA

---
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp     |  4 +++
 llvm/test/CodeGen/NVPTX/kernel-param-align.ll | 34 +++++++++++++++++++
 2 files changed, 38 insertions(+)
 create mode 100644 llvm/test/CodeGen/NVPTX/kernel-param-align.ll

diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 6c4879ba183c0a..0a0fbff2ad6c11 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1610,6 +1610,10 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
             }
             Align ParamAlign = I->getParamAlign().valueOrOne();
             O << ".align " << ParamAlign.value() << " ";
+          } else if (I->getParamAlign().valueOrOne() != 1) {
+            O << ".ptr .global ";
+            Align ParamAlign = I->getParamAlign().value();
+            O << ".align " << ParamAlign.value() << " ";
           }
           O << TLI->getParamName(F, paramIndex);
           continue;
diff --git a/llvm/test/CodeGen/NVPTX/kernel-param-align.ll b/llvm/test/CodeGen/NVPTX/kernel-param-align.ll
new file mode 100644
index 00000000000000..eda45928ea3059
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/kernel-param-align.ll
@@ -0,0 +1,34 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_72 2>&1 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_72 | %ptxas-verify %}
+
+%struct.Large = type { [16 x double] }
+
+; CHECK: .param .u64 .ptr .global .align 16 func_align_param_0,
+; CHECK: .param .u64 func_align_param_1,
+; CHECK: .param .u32 func_align_param_2
+define void @func_align(ptr nocapture readonly align 16 %input, ptr nocapture %out, i32 %n) {
+entry:
+  %0 = addrspacecast ptr %out to ptr addrspace(1)
+  %1 = addrspacecast ptr %input to ptr addrspace(1)
+  %getElem = getelementptr inbounds %struct.Large, ptr addrspace(1) %1, i64 0, i32 0, i64 5
+  %tmp2 = load i32, ptr addrspace(1) %getElem, align 8
+  store i32 %tmp2, ptr addrspace(1) %0, align 4
+  ret void
+}
+
+; CHECK: .param .u64 func_param_0,
+; CHECK: .param .u64 func_param_1,
+; CHECK: .param .u32 func_param_2
+define void @func(ptr nocapture readonly %input, ptr nocapture %out, i32 %n) {
+entry:
+  %0 = addrspacecast ptr %out to ptr addrspace(1)
+  %1 = addrspacecast ptr %input to ptr addrspace(1)
+  %getElem = getelementptr inbounds %struct.Large, ptr addrspace(1) %1, i64 0, i32 0, i64 5
+  %tmp2 = load i32, ptr addrspace(1) %getElem, align 8
+  store i32 %tmp2, ptr addrspace(1) %0, align 4
+  ret void
+}
+
+!nvvm.annotations = !{!0, !1}
+!0 = !{ptr @func_align, !"kernel", i32 1}
+!1 = !{ptr @func, !"kernel", i32 1}
\ No newline at end of file

>From a952cc1c655125829e49d11b46d7979b9052a671 Mon Sep 17 00:00:00 2001
From: Vandana <vandanak at nvidia.com>
Date: Thu, 1 Feb 2024 19:44:36 -0800
Subject: [PATCH 2/8] Rearrange code, add comment

---
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 0a0fbff2ad6c11..4d2da1f56050c3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1589,11 +1589,16 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
       if (isKernelFunc) {
         if (PTy) {
           // Special handling for pointer arguments to kernel
+          // CUDA kernels assume that pointers are in global address space
+          // See:
+          // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
           O << "\t.param .u" << PTySizeInBits << " ";
 
-          if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() !=
-              NVPTX::CUDA) {
-            int addrSpace = PTy->getAddressSpace();
+          int addrSpace = PTy->getAddressSpace();
+          if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
+            assert(addrSpace == 0 && "Invalid address space");
+            O << ".ptr .global ";
+          } else {
             switch (addrSpace) {
             default:
               O << ".ptr ";
@@ -1608,13 +1613,9 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
               O << ".ptr .global ";
               break;
             }
-            Align ParamAlign = I->getParamAlign().valueOrOne();
-            O << ".align " << ParamAlign.value() << " ";
-          } else if (I->getParamAlign().valueOrOne() != 1) {
-            O << ".ptr .global ";
-            Align ParamAlign = I->getParamAlign().value();
-            O << ".align " << ParamAlign.value() << " ";
           }
+          Align ParamAlign = I->getParamAlign().valueOrOne();
+          O << ".align " << ParamAlign.value() << " ";
           O << TLI->getParamName(F, paramIndex);
           continue;
         }

>From f5e72769cf3357c4ac18eb8295498d54d08343ad Mon Sep 17 00:00:00 2001
From: Vandana <vandanak at nvidia.com>
Date: Thu, 1 Feb 2024 19:50:49 -0800
Subject: [PATCH 3/8] Fixed clang formatting

---
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 4d2da1f56050c3..f7e5ce2e0f85a3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1595,7 +1595,8 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
           O << "\t.param .u" << PTySizeInBits << " ";
 
           int addrSpace = PTy->getAddressSpace();
-          if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
+          if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() ==
+              NVPTX::CUDA) {
             assert(addrSpace == 0 && "Invalid address space");
             O << ".ptr .global ";
           } else {

>From 761d8a0aeeef340948181bc92ac0f290aaccb051 Mon Sep 17 00:00:00 2001
From: Vandana2896 <129426835+Vandana2896 at users.noreply.github.com>
Date: Tue, 20 Feb 2024 14:50:55 -0800
Subject: [PATCH 4/8] Update NVPTXAsmPrinter.cpp

Clang formatting reverted.
---
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index f7e5ce2e0f85a3..4d2da1f56050c3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1595,8 +1595,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
           O << "\t.param .u" << PTySizeInBits << " ";
 
           int addrSpace = PTy->getAddressSpace();
-          if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() ==
-              NVPTX::CUDA) {
+          if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
             assert(addrSpace == 0 && "Invalid address space");
             O << ".ptr .global ";
           } else {

>From 3d49f303bf577c7ef5e308d30d92236a0899e7c8 Mon Sep 17 00:00:00 2001
From: Vandana2896 <129426835+Vandana2896 at users.noreply.github.com>
Date: Tue, 20 Feb 2024 15:07:09 -0800
Subject: [PATCH 5/8] Update NVPTXAsmPrinter.cpp

---
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 4d2da1f56050c3..f7e5ce2e0f85a3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1595,7 +1595,8 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
           O << "\t.param .u" << PTySizeInBits << " ";
 
           int addrSpace = PTy->getAddressSpace();
-          if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
+          if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() ==
+              NVPTX::CUDA) {
             assert(addrSpace == 0 && "Invalid address space");
             O << ".ptr .global ";
           } else {

>From 424667b4efa95fb996a7171f307ade2447eb86ec Mon Sep 17 00:00:00 2001
From: Vandana <vandanak at nvidia.com>
Date: Mon, 11 Mar 2024 04:28:12 -0700
Subject: [PATCH 6/8] Update .global and .align

---
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp     | 12 ++++++++++--
 llvm/test/CodeGen/NVPTX/kernel-param-align.ll |  6 +++---
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index f7e5ce2e0f85a3..3388edc8802fe8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1597,8 +1597,16 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
           int addrSpace = PTy->getAddressSpace();
           if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() ==
               NVPTX::CUDA) {
+            // Special handling for pointer arguments to kernel
+            // CUDA kernels assume that pointers are in global address space
+            // See:
+            // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
             assert(addrSpace == 0 && "Invalid address space");
             O << ".ptr .global ";
+            if (I->getParamAlign().valueOrOne() != 1) {
+              Align ParamAlign = I->getParamAlign().value();
+              O << ".align " << ParamAlign.value() << " ";
+            }
           } else {
             switch (addrSpace) {
             default:
@@ -1614,9 +1622,9 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
               O << ".ptr .global ";
               break;
             }
+            Align ParamAlign = I->getParamAlign().valueOrOne();
+            O << ".align " << ParamAlign.value() << " ";
           }
-          Align ParamAlign = I->getParamAlign().valueOrOne();
-          O << ".align " << ParamAlign.value() << " ";
           O << TLI->getParamName(F, paramIndex);
           continue;
         }
diff --git a/llvm/test/CodeGen/NVPTX/kernel-param-align.ll b/llvm/test/CodeGen/NVPTX/kernel-param-align.ll
index eda45928ea3059..81446a78bfc73b 100644
--- a/llvm/test/CodeGen/NVPTX/kernel-param-align.ll
+++ b/llvm/test/CodeGen/NVPTX/kernel-param-align.ll
@@ -4,7 +4,7 @@
 %struct.Large = type { [16 x double] }
 
 ; CHECK: .param .u64 .ptr .global .align 16 func_align_param_0,
-; CHECK: .param .u64 func_align_param_1,
+; CHECK: .param .u64 .ptr .global func_align_param_1,
 ; CHECK: .param .u32 func_align_param_2
 define void @func_align(ptr nocapture readonly align 16 %input, ptr nocapture %out, i32 %n) {
 entry:
@@ -16,8 +16,8 @@ entry:
   ret void
 }
 
-; CHECK: .param .u64 func_param_0,
-; CHECK: .param .u64 func_param_1,
+; CHECK: .param .u64 .ptr .global func_param_0,
+; CHECK: .param .u64 .ptr .global func_param_1,
 ; CHECK: .param .u32 func_param_2
 define void @func(ptr nocapture readonly %input, ptr nocapture %out, i32 %n) {
 entry:

>From 391221174f7d98bbca8aeecf80d430ade1c41635 Mon Sep 17 00:00:00 2001
From: Vandana <vandanak at nvidia.com>
Date: Mon, 11 Mar 2024 04:32:04 -0700
Subject: [PATCH 7/8] Fix comment

---
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 3388edc8802fe8..ebdca299811572 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1588,10 +1588,6 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
 
       if (isKernelFunc) {
         if (PTy) {
-          // Special handling for pointer arguments to kernel
-          // CUDA kernels assume that pointers are in global address space
-          // See:
-          // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
           O << "\t.param .u" << PTySizeInBits << " ";
 
           int addrSpace = PTy->getAddressSpace();

>From ce29dc1282d6ed5e9540c6b8e003c2dda01a4d0f Mon Sep 17 00:00:00 2001
From: Vandana <vandanak at nvidia.com>
Date: Thu, 25 Apr 2024 07:30:51 -0700
Subject: [PATCH 8/8] add addrspace

---
 llvm/test/CodeGen/NVPTX/kernel-param-align.ll | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/test/CodeGen/NVPTX/kernel-param-align.ll b/llvm/test/CodeGen/NVPTX/kernel-param-align.ll
index 81446a78bfc73b..bc6d7844dc79e9 100644
--- a/llvm/test/CodeGen/NVPTX/kernel-param-align.ll
+++ b/llvm/test/CodeGen/NVPTX/kernel-param-align.ll
@@ -5,8 +5,8 @@
 
 ; CHECK: .param .u64 .ptr .global .align 16 func_align_param_0,
 ; CHECK: .param .u64 .ptr .global func_align_param_1,
-; CHECK: .param .u32 func_align_param_2
-define void @func_align(ptr nocapture readonly align 16 %input, ptr nocapture %out, i32 %n) {
+; CHECK: .param .u32 .ptr .global func_align_param_2
+define void @func_align(ptr nocapture readonly align 16 %input, ptr addrspace(3) nocapture %out, i32 %n) {
 entry:
   %0 = addrspacecast ptr %out to ptr addrspace(1)
   %1 = addrspacecast ptr %input to ptr addrspace(1)



More information about the llvm-commits mailing list