[llvm] Enable .ptr .global .align attributes for kernel attributes for CUDA (PR #114874)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 7 11:23:41 PST 2024


================
@@ -1600,29 +1600,37 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
 
       if (isKernelFunc) {
         if (PTy) {
-          // Special handling for pointer arguments to kernel
           O << "\t.param .u" << PTySizeInBits << " ";
 
-          if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() !=
-              NVPTX::CUDA) {
-            int addrSpace = PTy->getAddressSpace();
-            switch (addrSpace) {
-            default:
-              O << ".ptr ";
-              break;
-            case ADDRESS_SPACE_CONST:
-              O << ".ptr .const ";
-              break;
-            case ADDRESS_SPACE_SHARED:
-              O << ".ptr .shared ";
-              break;
-            case ADDRESS_SPACE_GLOBAL:
-              O << ".ptr .global ";
-              break;
-            }
-            Align ParamAlign = I->getParamAlign().valueOrOne();
-            O << ".align " << ParamAlign.value() << " ";
+          int addrSpace = PTy->getAddressSpace();
+          const bool IsCUDA =
+              static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() ==
+              NVPTX::CUDA;
+
+          O << ".ptr ";
+          switch (addrSpace) {
+          default:
+            // Special handling for pointer arguments to kernel
+            // CUDA kernels assume that pointers are in global address space
+            // See:
+            // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
+            if (IsCUDA)
+              O << " .global ";
+            break;
+          case ADDRESS_SPACE_CONST:
+            O << " .const ";
+            break;
+          case ADDRESS_SPACE_SHARED:
+            O << " .shared ";
+            break;
+          case ADDRESS_SPACE_GLOBAL:
+            O << " .global ";
+            break;
           }
+
+          Align ParamAlign = I->getParamAlign().valueOrOne();
+          if (ParamAlign != 1 || !IsCUDA)
----------------
AlexMaclean wrote:

I'm a little unsure why it is safe to implicitly assume an alignment of 4 bytes for CUDA (which is what PTX assumes when `.align` is omitted). What are the semantics in LLVM IR when no alignment is specified on a parameter? Shouldn't we faithfully translate that alignment into PTX?

https://github.com/llvm/llvm-project/pull/114874


More information about the llvm-commits mailing list