[flang-commits] [flang] [flang][cuda] Avoid stack corruption when setting kernel launch parameters (PR #119469)

via flang-commits flang-commits at lists.llvm.org
Tue Dec 10 15:07:13 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: None (khaki3)

<details>
<summary>Changes</summary>

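To obtain a pointer to a structure member, `getelementptr` typically requires two indices: the first selects the structure itself (relative to the base pointer), and the second specifies the member's position within that structure. `GPULaunchKernelConversion` was omitting the first index, so the lone index was interpreted as an offset in whole-structure units rather than as a member position; for any argument after the first, the resulting store lands beyond the single allocated structure, which may cause stack corruption. This PR corrects the indices used to address the structure that serves as the temporary for kernel launch parameters.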


---
Full diff: https://github.com/llvm/llvm-project/pull/119469.diff


2 Files Affected:

- (modified) flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp (+4-1) 
- (modified) flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir (+1-1) 


``````````diff
diff --git a/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp b/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp
index 426cd52b7ef83e..60aa401e1cc8cc 100644
--- a/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp
@@ -42,6 +42,8 @@ static mlir::Value createKernelArgArray(mlir::Location loc,
   auto structTy = mlir::LLVM::LLVMStructType::getLiteral(ctx, structTypes);
   auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
   mlir::Type i32Ty = rewriter.getI32Type();
+  auto zero = rewriter.create<mlir::LLVM::ConstantOp>(
+      loc, i32Ty, rewriter.getIntegerAttr(i32Ty, 0));
   auto one = rewriter.create<mlir::LLVM::ConstantOp>(
       loc, i32Ty, rewriter.getIntegerAttr(i32Ty, 1));
   mlir::Value argStruct =
@@ -55,7 +57,8 @@ static mlir::Value createKernelArgArray(mlir::Location loc,
     auto indice = rewriter.create<mlir::LLVM::ConstantOp>(
         loc, i32Ty, rewriter.getIntegerAttr(i32Ty, i));
     mlir::Value structMember = rewriter.create<LLVM::GEPOp>(
-        loc, ptrTy, structTy, argStruct, mlir::ArrayRef<mlir::Value>({indice}));
+        loc, ptrTy, structTy, argStruct,
+        mlir::ArrayRef<mlir::Value>({zero, indice}));
     rewriter.create<LLVM::StoreOp>(loc, arg, structMember);
     mlir::Value arrayMember = rewriter.create<LLVM::GEPOp>(
         loc, ptrTy, ptrTy, argArray, mlir::ArrayRef<mlir::Value>({indice}));
diff --git a/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir b/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
index accdeae30aa61c..3db2336c90a7d4 100644
--- a/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
+++ b/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
@@ -102,7 +102,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<i1, dense<8> : ve
 // CHECK: %[[STRUCT:.*]] = llvm.alloca %{{.*}} x !llvm.struct<(ptr)> : (i32) -> !llvm.ptr
 // CHECK: %[[PARAMS:.*]] = llvm.alloca %{{.*}} x !llvm.ptr : (i32) -> !llvm.ptr
 // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK: %[[STRUCT_PTR:.*]] = llvm.getelementptr %[[STRUCT]][%[[ZERO]]] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(ptr)>
+// CHECK: %[[STRUCT_PTR:.*]] = llvm.getelementptr %[[STRUCT]][%{{.*}}, {{.*}}] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(ptr)>
 // CHECK: llvm.store %{{.*}}, %[[STRUCT_PTR]] : !llvm.ptr, !llvm.ptr
 // CHECK: %[[PARAM_PTR:.*]] = llvm.getelementptr %[[PARAMS]][%[[ZERO]]] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.ptr
 // CHECK: llvm.store %[[STRUCT_PTR]], %[[PARAM_PTR]] : !llvm.ptr, !llvm.ptr

``````````

</details>


https://github.com/llvm/llvm-project/pull/119469


More information about the flang-commits mailing list