[llvm] [NVPTX] Don't use underlying alignment to align param (PR #96793)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 26 09:23:58 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-backend-nvptx

Author: Hugh Delaney (hdelan)

<details>
<summary>Changes</summary>

Previously, if a `ptr` parameter had `align N`, the NVPTX lowering took `align N` to refer to the alignment of the parameter itself (the storage holding the pointer value), as opposed to the alignment of the memory that the pointer points to.

As such, for a function of the form:

```
define void @foo(ptr align 4 %_arg_ptr)
```

the lowering would take `align 4` to be the alignment of the 8-byte parameter itself, which resulted in the `ld.param` being broken into two separate 32-bit loads that are then recombined:

```
	ld.param.u32 	%rd1, [foo_param_0+4];
	shl.b64 	%rd2, %rd1, 32;
	ld.param.u32 	%rd3, [foo_param_0];
	or.b64  	%rd4, %rd2, %rd3;
```
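The two-load sequence above costs performance but still computes the right value: PTX memory is little-endian, so loading the low and high 32-bit halves separately and recombining them with `shl`/`or` gives the same 64-bit pointer as a single `ld.param.u64`. A minimal host-side C++ sketch of that equivalence follows. It models the `.param` slot as a plain byte array, and `ParamSlot`, `makeSlot`, `ldParamU32`, `ldParamU64` and `ldParamSplit` are hypothetical names used only for illustration, not NVPTX or LLVM APIs.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical model of the 8-byte .param slot holding `ptr %_arg_ptr`.
// PTX memory is little-endian, so the helpers assemble bytes low-to-high
// explicitly; this keeps the model independent of the host's byte order.
using ParamSlot = std::array<std::uint8_t, 8>;

// Stores a 64-bit pointer value into the slot in little-endian order.
ParamSlot makeSlot(std::uint64_t ptrValue) {
  ParamSlot slot{};
  for (std::size_t i = 0; i < slot.size(); ++i)
    slot[i] = static_cast<std::uint8_t>(ptrValue >> (8 * i));
  return slot;
}

// Models `ld.param.u32 dst, [foo_param_0+off]`.
std::uint32_t ldParamU32(const ParamSlot &slot, std::size_t off) {
  std::uint32_t v = 0;
  for (std::size_t i = 0; i < 4; ++i)
    v |= static_cast<std::uint32_t>(slot[off + i]) << (8 * i);
  return v;
}

// Models `ld.param.u64 dst, [foo_param_0]`.
std::uint64_t ldParamU64(const ParamSlot &slot) {
  std::uint64_t v = 0;
  for (std::size_t i = 0; i < 8; ++i)
    v |= static_cast<std::uint64_t>(slot[i]) << (8 * i);
  return v;
}

// Models the old split sequence: two 32-bit loads recombined with shl/or.
std::uint64_t ldParamSplit(const ParamSlot &slot) {
  std::uint64_t hi = ldParamU32(slot, 4); // ld.param.u32 %rd1, [foo_param_0+4]
  std::uint64_t lo = ldParamU32(slot, 0); // ld.param.u32 %rd3, [foo_param_0]
  return (hi << 32) | lo;                 // shl.b64 + or.b64 -> %rd4
}
```

For any 64-bit value `p`, `ldParamSplit(makeSlot(p))` and `ldParamU64(makeSlot(p))` both return `p`; the split form simply spends an extra load, a shift and an `or` to get there.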

As far as I can tell from the PTX ISA documentation, it isn't necessary to specify the alignment of the parameters themselves. This patch therefore uses the ABI alignment of the parameter's type instead, which changes the codegen to the better:

```
	ld.param.u64 	%rd1, [foo_param_0];
```

Ping @frasercrmck @ldrumm

---
Full diff: https://github.com/llvm/llvm-project/pull/96793.diff


2 Files Affected:

- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+1-3) 
- (modified) llvm/test/CodeGen/NVPTX/param-align.ll (+36) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 982c191875750..63cbdb0acfab6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3232,9 +3232,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
             if (NumElts != 1)
               return std::nullopt;
             Align PartAlign =
-                (Offsets[parti] == 0 && PAL.getParamAlignment(i))
-                    ? PAL.getParamAlignment(i).value()
-                    : DL.getABITypeAlign(EltVT.getTypeForEVT(F->getContext()));
+                DL.getABITypeAlign(EltVT.getTypeForEVT(F->getContext()));
             return commonAlignment(PartAlign, Offsets[parti]);
           }();
           SDValue P = DAG.getLoad(VecVT, dl, Root, VecAddr,
diff --git a/llvm/test/CodeGen/NVPTX/param-align.ll b/llvm/test/CodeGen/NVPTX/param-align.ll
index 5435ee238c88d..6ef284aa4b5eb 100644
--- a/llvm/test/CodeGen/NVPTX/param-align.ll
+++ b/llvm/test/CodeGen/NVPTX/param-align.ll
@@ -69,3 +69,39 @@ define ptx_device void @t6() {
   call void %fp(ptr byval(i8) null);
   ret void
 }
+
+; CHECK: .func check_ptr_align1(
+; CHECK: 	ld.param.u64 	%rd1
+; CHECK: 	ret;
+define void @check_ptr_align1(ptr align 1 %_arg_ptr) {
+entry:
+  store i32 1, ptr %_arg_ptr, align 1
+  ret void
+}
+
+; CHECK: .func check_ptr_align2(
+; CHECK: 	ld.param.u64 	%rd1
+; CHECK: 	ret;
+define void @check_ptr_align2(ptr align 2 %_arg_ptr) {
+entry:
+  store i32 2, ptr %_arg_ptr, align 2
+  ret void
+}
+
+; CHECK: .func check_ptr_align4(
+; CHECK: 	ld.param.u64 	%rd1
+; CHECK: 	ret;
+define void @check_ptr_align4(ptr align 4 %_arg_ptr) {
+entry:
+  store i32 4, ptr %_arg_ptr, align 4
+  ret void
+}
+
+; CHECK: .func check_ptr_align8(
+; CHECK: 	ld.param.u64 	%rd1
+; CHECK: 	ret;
+define void @check_ptr_align8(ptr align 8 %_arg_ptr) {
+entry:
+  store i32 8, ptr %_arg_ptr, align 8
+  ret void
+}

``````````
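The one-line change in `LowerFormalArguments` above can be modelled outside LLVM to see why `align 4` produced the split. The C++ sketch below mirrors the before/after ternary using plain integers instead of LLVM's `Align`/`MaybeAlign` types. `commonAlign`, `partAlignBefore`, `partAlignAfter` and `loadIsSplit` are hypothetical names, and the split rule in `loadIsSplit` is a simplifying assumption, not the DAG legalizer's actual logic.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Illustrative stand-in for llvm::MinAlign / llvm::commonAlignment: the largest
// power of two dividing both `a` and `b` (an offset of 0 leaves `a` unchanged).
constexpr std::uint64_t commonAlign(std::uint64_t a, std::uint64_t b) {
  return (a | b) & (1 + ~(a | b));
}

// Before the patch: the part at offset 0 took its alignment from the
// parameter's `align N` attribute (the *pointee* alignment) when present.
std::uint64_t partAlignBefore(std::optional<std::uint64_t> paramAlignAttr,
                              std::uint64_t abiTypeAlign,
                              std::uint64_t offset) {
  std::uint64_t partAlign =
      (offset == 0 && paramAlignAttr) ? *paramAlignAttr : abiTypeAlign;
  return commonAlign(partAlign, offset);
}

// After the patch: always use the ABI alignment of the part's own type.
std::uint64_t partAlignAfter(std::uint64_t abiTypeAlign, std::uint64_t offset) {
  return commonAlign(abiTypeAlign, offset);
}

// Simplifying assumption: a load of `bytes` bytes stays a single ld.param only
// when its known alignment is at least its size; otherwise it gets split.
bool loadIsSplit(std::uint64_t bytes, std::uint64_t align) {
  return align < bytes;
}
```

Assuming a 64-bit pointer with an 8-byte ABI alignment (as in the nvptx64 data layout), `check_ptr_align4` maps to `partAlignBefore(4, 8, 0) == 4`, so the 8-byte load is split, whereas `partAlignAfter(8, 0) == 8` keeps a single `ld.param.u64`. Under this model, only pointee alignments below 8 differ between the two rules.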

</details>


https://github.com/llvm/llvm-project/pull/96793

