[llvm] [NVPTX] Copy kernel arguments as byte array (PR #110356)

Michael Kuron via llvm-commits llvm-commits at lists.llvm.org
Sat Sep 28 06:30:41 PDT 2024


https://github.com/mkuron updated https://github.com/llvm/llvm-project/pull/110356

>From 1989e313c49774ed1c237ca0bd66762c3a5ab8a5 Mon Sep 17 00:00:00 2001
From: Michael Kuron <m.kuron at gmx.de>
Date: Sat, 28 Sep 2024 12:57:43 +0200
Subject: [PATCH] [NVPTX] Copy kernel arguments as byte array

Ensures that struct padding is not skipped, as it may contain actual
data if the struct is really a union.

Fixes #53710
---
 llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp    |  8 ++-
 llvm/test/CodeGen/NVPTX/lower-args.ll       |  2 +-
 llvm/test/CodeGen/NVPTX/lower-byval-args.ll | 64 ++++++++++-----------
 3 files changed, 40 insertions(+), 34 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index 082546c4dd72f8..2b67b02516ae4e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -626,8 +626,14 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
     // Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
     // addrspacecast preserves alignment.  Since params are constant, this load
     // is definitely not volatile.
+    const auto StructBytes = *AllocA->getAllocationSize(DL);
+    const auto ChunkBytes = (StructBytes % 8 == 0) ? 8 :
+                            (StructBytes % 4 == 0) ? 4 :
+                            (StructBytes % 2 == 0) ? 2 : 1;
+    Type *ChunkType = Type::getIntNTy(Func->getContext(), 8 * ChunkBytes);
+    Type *OpaqueType = ArrayType::get(ChunkType, StructBytes / ChunkBytes);
     LoadInst *LI =
-        new LoadInst(StructType, ArgInParam, Arg->getName(),
+        new LoadInst(OpaqueType, ArgInParam, Arg->getName(),
                      /*isVolatile=*/false, AllocA->getAlign(), FirstInst);
     new StoreInst(LI, AllocA, FirstInst);
   }
diff --git a/llvm/test/CodeGen/NVPTX/lower-args.ll b/llvm/test/CodeGen/NVPTX/lower-args.ll
index 029f1944d596b3..9a306036044be4 100644
--- a/llvm/test/CodeGen/NVPTX/lower-args.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args.ll
@@ -14,7 +14,7 @@ target triple = "nvptx64-nvidia-cuda"
 ; COMMON-LABEL: load_alignment
 define void @load_alignment(ptr nocapture readonly byval(%class.outer) align 8 %arg) {
 entry:
-; IR: load %class.outer, ptr addrspace(101)
+; IR: load [3 x i64], ptr addrspace(101)
 ; IR-SAME: align 8
 ; PTX: ld.param.u64
 ; PTX-NOT: ld.param.u8
diff --git a/llvm/test/CodeGen/NVPTX/lower-byval-args.ll b/llvm/test/CodeGen/NVPTX/lower-byval-args.ll
index a414a6c41cd5b2..91ff876f0b718f 100644
--- a/llvm/test/CodeGen/NVPTX/lower-byval-args.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-byval-args.ll
@@ -88,8 +88,8 @@ define dso_local void @read_only_gep_asc0(ptr nocapture noundef writeonly %out,
 ; COMMON-NEXT:  [[ENTRY:.*:]]
 ; COMMON-NEXT:    [[S3:%.*]] = alloca [[STRUCT_S]], align 4
 ; COMMON-NEXT:    [[S4:%.*]] = addrspacecast ptr [[S]] to ptr addrspace(101)
-; COMMON-NEXT:    [[S5:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[S4]], align 4
-; COMMON-NEXT:    store [[STRUCT_S]] [[S5]], ptr [[S3]], align 4
+; COMMON-NEXT:    [[S5:%.*]] = load [1 x i64], ptr addrspace(101) [[S4]], align 4
+; COMMON-NEXT:    store [1 x i64] [[S5]], ptr [[S3]], align 8
 ; COMMON-NEXT:    [[OUT1:%.*]] = addrspacecast ptr [[OUT]] to ptr addrspace(1)
 ; COMMON-NEXT:    [[OUT2:%.*]] = addrspacecast ptr addrspace(1) [[OUT1]] to ptr
 ; COMMON-NEXT:    [[B:%.*]] = getelementptr inbounds nuw i8, ptr [[S3]], i64 4
@@ -115,8 +115,8 @@ define dso_local void @escape_ptr(ptr nocapture noundef readnone %out, ptr nound
 ; COMMON-NEXT:  [[ENTRY:.*:]]
 ; COMMON-NEXT:    [[S3:%.*]] = alloca [[STRUCT_S]], align 4
 ; COMMON-NEXT:    [[S4:%.*]] = addrspacecast ptr [[S]] to ptr addrspace(101)
-; COMMON-NEXT:    [[S5:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[S4]], align 4
-; COMMON-NEXT:    store [[STRUCT_S]] [[S5]], ptr [[S3]], align 4
+; COMMON-NEXT:    [[S5:%.*]] = load [1 x i64], ptr addrspace(101) [[S4]], align 4
+; COMMON-NEXT:    store [1 x i64] [[S5]], ptr [[S3]], align 8
 ; COMMON-NEXT:    [[OUT1:%.*]] = addrspacecast ptr [[OUT]] to ptr addrspace(1)
 ; COMMON-NEXT:    [[OUT2:%.*]] = addrspacecast ptr addrspace(1) [[OUT1]] to ptr
 ; COMMON-NEXT:    call void @_Z6escapePv(ptr noundef nonnull [[S3]])
@@ -134,8 +134,8 @@ define dso_local void @escape_ptr_gep(ptr nocapture noundef readnone %out, ptr n
 ; COMMON-NEXT:  [[ENTRY:.*:]]
 ; COMMON-NEXT:    [[S3:%.*]] = alloca [[STRUCT_S]], align 4
 ; COMMON-NEXT:    [[S4:%.*]] = addrspacecast ptr [[S]] to ptr addrspace(101)
-; COMMON-NEXT:    [[S5:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[S4]], align 4
-; COMMON-NEXT:    store [[STRUCT_S]] [[S5]], ptr [[S3]], align 4
+; COMMON-NEXT:    [[S5:%.*]] = load [1 x i64], ptr addrspace(101) [[S4]], align 4
+; COMMON-NEXT:    store [1 x i64] [[S5]], ptr [[S3]], align 8
 ; COMMON-NEXT:    [[OUT1:%.*]] = addrspacecast ptr [[OUT]] to ptr addrspace(1)
 ; COMMON-NEXT:    [[OUT2:%.*]] = addrspacecast ptr addrspace(1) [[OUT1]] to ptr
 ; COMMON-NEXT:    [[B:%.*]] = getelementptr inbounds nuw i8, ptr [[S3]], i64 4
@@ -155,8 +155,8 @@ define dso_local void @escape_ptr_store(ptr nocapture noundef writeonly %out, pt
 ; COMMON-NEXT:  [[ENTRY:.*:]]
 ; COMMON-NEXT:    [[S3:%.*]] = alloca [[STRUCT_S]], align 4
 ; COMMON-NEXT:    [[S4:%.*]] = addrspacecast ptr [[S]] to ptr addrspace(101)
-; COMMON-NEXT:    [[S5:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[S4]], align 4
-; COMMON-NEXT:    store [[STRUCT_S]] [[S5]], ptr [[S3]], align 4
+; COMMON-NEXT:    [[S5:%.*]] = load [1 x i64], ptr addrspace(101) [[S4]], align 4
+; COMMON-NEXT:    store [1 x i64] [[S5]], ptr [[S3]], align 8
 ; COMMON-NEXT:    [[OUT1:%.*]] = addrspacecast ptr [[OUT]] to ptr addrspace(1)
 ; COMMON-NEXT:    [[OUT2:%.*]] = addrspacecast ptr addrspace(1) [[OUT1]] to ptr
 ; COMMON-NEXT:    store ptr [[S3]], ptr [[OUT2]], align 8
@@ -174,8 +174,8 @@ define dso_local void @escape_ptr_gep_store(ptr nocapture noundef writeonly %out
 ; COMMON-NEXT:  [[ENTRY:.*:]]
 ; COMMON-NEXT:    [[S3:%.*]] = alloca [[STRUCT_S]], align 4
 ; COMMON-NEXT:    [[S4:%.*]] = addrspacecast ptr [[S]] to ptr addrspace(101)
-; COMMON-NEXT:    [[S5:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[S4]], align 4
-; COMMON-NEXT:    store [[STRUCT_S]] [[S5]], ptr [[S3]], align 4
+; COMMON-NEXT:    [[S5:%.*]] = load [1 x i64], ptr addrspace(101) [[S4]], align 4
+; COMMON-NEXT:    store [1 x i64] [[S5]], ptr [[S3]], align 8
 ; COMMON-NEXT:    [[OUT1:%.*]] = addrspacecast ptr [[OUT]] to ptr addrspace(1)
 ; COMMON-NEXT:    [[OUT2:%.*]] = addrspacecast ptr addrspace(1) [[OUT1]] to ptr
 ; COMMON-NEXT:    [[B:%.*]] = getelementptr inbounds nuw i8, ptr [[S3]], i64 4
@@ -195,8 +195,8 @@ define dso_local void @escape_ptrtoint(ptr nocapture noundef writeonly %out, ptr
 ; COMMON-NEXT:  [[ENTRY:.*:]]
 ; COMMON-NEXT:    [[S3:%.*]] = alloca [[STRUCT_S]], align 4
 ; COMMON-NEXT:    [[S4:%.*]] = addrspacecast ptr [[S]] to ptr addrspace(101)
-; COMMON-NEXT:    [[S5:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[S4]], align 4
-; COMMON-NEXT:    store [[STRUCT_S]] [[S5]], ptr [[S3]], align 4
+; COMMON-NEXT:    [[S5:%.*]] = load [1 x i64], ptr addrspace(101) [[S4]], align 4
+; COMMON-NEXT:    store [1 x i64] [[S5]], ptr [[S3]], align 8
 ; COMMON-NEXT:    [[OUT1:%.*]] = addrspacecast ptr [[OUT]] to ptr addrspace(1)
 ; COMMON-NEXT:    [[OUT2:%.*]] = addrspacecast ptr addrspace(1) [[OUT1]] to ptr
 ; COMMON-NEXT:    [[I:%.*]] = ptrtoint ptr [[S3]] to i64
@@ -232,8 +232,8 @@ define dso_local void @memcpy_to_param(ptr nocapture noundef readonly %in, ptr n
 ; COMMON-NEXT:  [[ENTRY:.*:]]
 ; COMMON-NEXT:    [[S3:%.*]] = alloca [[STRUCT_S]], align 4
 ; COMMON-NEXT:    [[S4:%.*]] = addrspacecast ptr [[S]] to ptr addrspace(101)
-; COMMON-NEXT:    [[S5:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[S4]], align 4
-; COMMON-NEXT:    store [[STRUCT_S]] [[S5]], ptr [[S3]], align 4
+; COMMON-NEXT:    [[S5:%.*]] = load [1 x i64], ptr addrspace(101) [[S4]], align 4
+; COMMON-NEXT:    store [1 x i64] [[S5]], ptr [[S3]], align 8
 ; COMMON-NEXT:    [[IN1:%.*]] = addrspacecast ptr [[IN]] to ptr addrspace(1)
 ; COMMON-NEXT:    [[IN2:%.*]] = addrspacecast ptr addrspace(1) [[IN1]] to ptr
 ; COMMON-NEXT:    tail call void @llvm.memcpy.p0.p0.i64(ptr [[S3]], ptr [[IN2]], i64 16, i1 true)
@@ -251,8 +251,8 @@ define dso_local void @copy_on_store(ptr nocapture noundef readonly %in, ptr noc
 ; COMMON-NEXT:  [[BB:.*:]]
 ; COMMON-NEXT:    [[S3:%.*]] = alloca [[STRUCT_S]], align 4
 ; COMMON-NEXT:    [[S4:%.*]] = addrspacecast ptr [[S]] to ptr addrspace(101)
-; COMMON-NEXT:    [[S5:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[S4]], align 4
-; COMMON-NEXT:    store [[STRUCT_S]] [[S5]], ptr [[S3]], align 4
+; COMMON-NEXT:    [[S5:%.*]] = load [1 x i64], ptr addrspace(101) [[S4]], align 4
+; COMMON-NEXT:    store [1 x i64] [[S5]], ptr [[S3]], align 8
 ; COMMON-NEXT:    [[IN1:%.*]] = addrspacecast ptr [[IN]] to ptr addrspace(1)
 ; COMMON-NEXT:    [[IN2:%.*]] = addrspacecast ptr addrspace(1) [[IN1]] to ptr
 ; COMMON-NEXT:    [[I:%.*]] = load i32, ptr [[IN2]], align 4
@@ -273,12 +273,12 @@ define void @test_select(ptr byval(i32) align 4 %input1, ptr byval(i32) %input2,
 ; SM_60-NEXT:    [[OUT8:%.*]] = addrspacecast ptr addrspace(1) [[OUT7]] to ptr
 ; SM_60-NEXT:    [[INPUT24:%.*]] = alloca i32, align 4
 ; SM_60-NEXT:    [[INPUT25:%.*]] = addrspacecast ptr [[INPUT2]] to ptr addrspace(101)
-; SM_60-NEXT:    [[INPUT26:%.*]] = load i32, ptr addrspace(101) [[INPUT25]], align 4
-; SM_60-NEXT:    store i32 [[INPUT26]], ptr [[INPUT24]], align 4
+; SM_60-NEXT:    [[INPUT26:%.*]] = load [1 x i32], ptr addrspace(101) [[INPUT25]], align 4
+; SM_60-NEXT:    store [1 x i32] [[INPUT26]], ptr [[INPUT24]], align 4
 ; SM_60-NEXT:    [[INPUT11:%.*]] = alloca i32, align 4
 ; SM_60-NEXT:    [[INPUT12:%.*]] = addrspacecast ptr [[INPUT1]] to ptr addrspace(101)
-; SM_60-NEXT:    [[INPUT13:%.*]] = load i32, ptr addrspace(101) [[INPUT12]], align 4
-; SM_60-NEXT:    store i32 [[INPUT13]], ptr [[INPUT11]], align 4
+; SM_60-NEXT:    [[INPUT13:%.*]] = load [1 x i32], ptr addrspace(101) [[INPUT12]], align 4
+; SM_60-NEXT:    store [1 x i32] [[INPUT13]], ptr [[INPUT11]], align 4
 ; SM_60-NEXT:    [[PTRNEW:%.*]] = select i1 [[COND]], ptr [[INPUT11]], ptr [[INPUT24]]
 ; SM_60-NEXT:    [[VALLOADED:%.*]] = load i32, ptr [[PTRNEW]], align 4
 ; SM_60-NEXT:    store i32 [[VALLOADED]], ptr [[OUT8]], align 4
@@ -313,12 +313,12 @@ define void @test_select_write(ptr byval(i32) align 4 %input1, ptr byval(i32) %i
 ; COMMON-NEXT:    [[OUT8:%.*]] = addrspacecast ptr addrspace(1) [[OUT7]] to ptr
 ; COMMON-NEXT:    [[INPUT24:%.*]] = alloca i32, align 4
 ; COMMON-NEXT:    [[INPUT25:%.*]] = addrspacecast ptr [[INPUT2]] to ptr addrspace(101)
-; COMMON-NEXT:    [[INPUT26:%.*]] = load i32, ptr addrspace(101) [[INPUT25]], align 4
-; COMMON-NEXT:    store i32 [[INPUT26]], ptr [[INPUT24]], align 4
+; COMMON-NEXT:    [[INPUT26:%.*]] = load [1 x i32], ptr addrspace(101) [[INPUT25]], align 4
+; COMMON-NEXT:    store [1 x i32] [[INPUT26]], ptr [[INPUT24]], align 4
 ; COMMON-NEXT:    [[INPUT11:%.*]] = alloca i32, align 4
 ; COMMON-NEXT:    [[INPUT12:%.*]] = addrspacecast ptr [[INPUT1]] to ptr addrspace(101)
-; COMMON-NEXT:    [[INPUT13:%.*]] = load i32, ptr addrspace(101) [[INPUT12]], align 4
-; COMMON-NEXT:    store i32 [[INPUT13]], ptr [[INPUT11]], align 4
+; COMMON-NEXT:    [[INPUT13:%.*]] = load [1 x i32], ptr addrspace(101) [[INPUT12]], align 4
+; COMMON-NEXT:    store [1 x i32] [[INPUT13]], ptr [[INPUT11]], align 4
 ; COMMON-NEXT:    [[PTRNEW:%.*]] = select i1 [[COND]], ptr [[INPUT11]], ptr [[INPUT24]]
 ; COMMON-NEXT:    store i32 1, ptr [[PTRNEW]], align 4
 ; COMMON-NEXT:    ret void
@@ -337,12 +337,12 @@ define void @test_phi(ptr byval(%struct.S) align 4 %input1, ptr byval(%struct.S)
 ; SM_60-NEXT:    [[INOUT8:%.*]] = addrspacecast ptr addrspace(1) [[INOUT7]] to ptr
 ; SM_60-NEXT:    [[INPUT24:%.*]] = alloca [[STRUCT_S]], align 8
 ; SM_60-NEXT:    [[INPUT25:%.*]] = addrspacecast ptr [[INPUT2]] to ptr addrspace(101)
-; SM_60-NEXT:    [[INPUT26:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[INPUT25]], align 8
-; SM_60-NEXT:    store [[STRUCT_S]] [[INPUT26]], ptr [[INPUT24]], align 4
+; SM_60-NEXT:    [[INPUT26:%.*]] = load [1 x i64], ptr addrspace(101) [[INPUT25]], align 8
+; SM_60-NEXT:    store [1 x i64] [[INPUT26]], ptr [[INPUT24]], align 8
 ; SM_60-NEXT:    [[INPUT11:%.*]] = alloca [[STRUCT_S]], align 4
 ; SM_60-NEXT:    [[INPUT12:%.*]] = addrspacecast ptr [[INPUT1]] to ptr addrspace(101)
-; SM_60-NEXT:    [[INPUT13:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[INPUT12]], align 4
-; SM_60-NEXT:    store [[STRUCT_S]] [[INPUT13]], ptr [[INPUT11]], align 4
+; SM_60-NEXT:    [[INPUT13:%.*]] = load [1 x i64], ptr addrspace(101) [[INPUT12]], align 4
+; SM_60-NEXT:    store [1 x i64] [[INPUT13]], ptr [[INPUT11]], align 8
 ; SM_60-NEXT:    br i1 [[COND]], label %[[FIRST:.*]], label %[[SECOND:.*]]
 ; SM_60:       [[FIRST]]:
 ; SM_60-NEXT:    [[PTR1:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[INPUT11]], i32 0, i32 0
@@ -402,12 +402,12 @@ define void @test_phi_write(ptr byval(%struct.S) align 4 %input1, ptr byval(%str
 ; COMMON-NEXT:  [[BB:.*:]]
 ; COMMON-NEXT:    [[INPUT24:%.*]] = alloca [[STRUCT_S]], align 8
 ; COMMON-NEXT:    [[INPUT25:%.*]] = addrspacecast ptr [[INPUT2]] to ptr addrspace(101)
-; COMMON-NEXT:    [[INPUT26:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[INPUT25]], align 8
-; COMMON-NEXT:    store [[STRUCT_S]] [[INPUT26]], ptr [[INPUT24]], align 4
+; COMMON-NEXT:    [[INPUT26:%.*]] = load [1 x i64], ptr addrspace(101) [[INPUT25]], align 8
+; COMMON-NEXT:    store [1 x i64] [[INPUT26]], ptr [[INPUT24]], align 8
 ; COMMON-NEXT:    [[INPUT11:%.*]] = alloca [[STRUCT_S]], align 4
 ; COMMON-NEXT:    [[INPUT12:%.*]] = addrspacecast ptr [[INPUT1]] to ptr addrspace(101)
-; COMMON-NEXT:    [[INPUT13:%.*]] = load [[STRUCT_S]], ptr addrspace(101) [[INPUT12]], align 4
-; COMMON-NEXT:    store [[STRUCT_S]] [[INPUT13]], ptr [[INPUT11]], align 4
+; COMMON-NEXT:    [[INPUT13:%.*]] = load [1 x i64], ptr addrspace(101) [[INPUT12]], align 4
+; COMMON-NEXT:    store [1 x i64] [[INPUT13]], ptr [[INPUT11]], align 8
 ; COMMON-NEXT:    br i1 [[COND]], label %[[FIRST:.*]], label %[[SECOND:.*]]
 ; COMMON:       [[FIRST]]:
 ; COMMON-NEXT:    [[PTR1:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[INPUT11]], i32 0, i32 0



More information about the llvm-commits mailing list