[flang-commits] [flang] [flang][cuda] Avoid stack corruption when setting kernel launch parameters (PR #119469)
via flang-commits
flang-commits at lists.llvm.org
Tue Dec 10 15:07:13 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-flang-fir-hlfir
Author: None (khaki3)
<details>
<summary>Changes</summary>
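To get a pointer to a structure member, `getelementptr` typically requires two indices: the first indexes through the pointer operand to the structure itself, and the second selects the member's position. `GPULaunchKernelConversion` was missing the first index, so the lone index was interpreted as an offset in whole structures rather than as a member position. For any member other than the first, the resulting store can land outside the allocated temporary and corrupt the stack. This PR corrects the indices by adding a leading zero to the GEP on the structure used as a kernel launch temp.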
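
To illustrate the difference, here is a minimal sketch that models the byte offset a GEP computes with one index versus two. It is not code from the PR: the `Layout` type and the helpers `gepOneIndex`, `gepTwoIndices`, and `threePointerArgs` are hypothetical names introduced for this example, and the layout assumes 8-byte pointers with no padding.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model of a struct layout: the byte offset of each member
// and the total size of the struct in bytes.
struct Layout {
  std::vector<std::size_t> memberOffsets;
  std::size_t size;
};

// GEP with a single index: the index scales by the size of the whole
// struct, as if the pointer addressed an array of structs.
std::size_t gepOneIndex(const Layout &s, std::size_t idx) {
  return idx * s.size;
}

// GEP with two indices: the first steps over whole structs, the second
// selects a member inside the chosen struct.
std::size_t gepTwoIndices(const Layout &s, std::size_t first,
                          std::size_t member) {
  assert(member < s.memberOffsets.size() && "member index out of range");
  return first * s.size + s.memberOffsets[member];
}

// Assumed example: a kernel-argument struct holding three 8-byte pointers.
Layout threePointerArgs() { return Layout{{0, 8, 16}, 24}; }
```

Under these assumptions, `gepOneIndex(threePointerArgs(), 1)` yields 24, which is already past the end of the 24-byte temporary, while `gepTwoIndices(threePointerArgs(), 0, 1)` yields 8, the intended member. Both forms agree for index 0, which is consistent with the single-member struct in the existing test not exposing the problem.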
---
Full diff: https://github.com/llvm/llvm-project/pull/119469.diff
2 Files Affected:
- (modified) flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp (+4-1)
- (modified) flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir (+1-1)
``````````diff
diff --git a/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp b/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp
index 426cd52b7ef83e..60aa401e1cc8cc 100644
--- a/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp
@@ -42,6 +42,8 @@ static mlir::Value createKernelArgArray(mlir::Location loc,
auto structTy = mlir::LLVM::LLVMStructType::getLiteral(ctx, structTypes);
auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
mlir::Type i32Ty = rewriter.getI32Type();
+ auto zero = rewriter.create<mlir::LLVM::ConstantOp>(
+ loc, i32Ty, rewriter.getIntegerAttr(i32Ty, 0));
auto one = rewriter.create<mlir::LLVM::ConstantOp>(
loc, i32Ty, rewriter.getIntegerAttr(i32Ty, 1));
mlir::Value argStruct =
@@ -55,7 +57,8 @@ static mlir::Value createKernelArgArray(mlir::Location loc,
auto indice = rewriter.create<mlir::LLVM::ConstantOp>(
loc, i32Ty, rewriter.getIntegerAttr(i32Ty, i));
mlir::Value structMember = rewriter.create<LLVM::GEPOp>(
- loc, ptrTy, structTy, argStruct, mlir::ArrayRef<mlir::Value>({indice}));
+ loc, ptrTy, structTy, argStruct,
+ mlir::ArrayRef<mlir::Value>({zero, indice}));
rewriter.create<LLVM::StoreOp>(loc, arg, structMember);
mlir::Value arrayMember = rewriter.create<LLVM::GEPOp>(
loc, ptrTy, ptrTy, argArray, mlir::ArrayRef<mlir::Value>({indice}));
diff --git a/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir b/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
index accdeae30aa61c..3db2336c90a7d4 100644
--- a/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
+++ b/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
@@ -102,7 +102,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<i1, dense<8> : ve
// CHECK: %[[STRUCT:.*]] = llvm.alloca %{{.*}} x !llvm.struct<(ptr)> : (i32) -> !llvm.ptr
// CHECK: %[[PARAMS:.*]] = llvm.alloca %{{.*}} x !llvm.ptr : (i32) -> !llvm.ptr
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK: %[[STRUCT_PTR:.*]] = llvm.getelementptr %[[STRUCT]][%[[ZERO]]] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(ptr)>
+// CHECK: %[[STRUCT_PTR:.*]] = llvm.getelementptr %[[STRUCT]][%{{.*}}, {{.*}}] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(ptr)>
// CHECK: llvm.store %{{.*}}, %[[STRUCT_PTR]] : !llvm.ptr, !llvm.ptr
// CHECK: %[[PARAM_PTR:.*]] = llvm.getelementptr %[[PARAMS]][%[[ZERO]]] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.ptr
// CHECK: llvm.store %[[STRUCT_PTR]], %[[PARAM_PTR]] : !llvm.ptr, !llvm.ptr
``````````
</details>
https://github.com/llvm/llvm-project/pull/119469