[flang-commits] [flang] [flang][cuda] Flatten memref descriptors in GPU kernel argument packing (PR #193651)
Razvan Lupusoru via flang-commits
flang-commits at lists.llvm.org
Thu Apr 23 08:18:31 PDT 2026
================
@@ -33,16 +34,41 @@ using namespace Fortran::runtime;
namespace {
+// Build the kernel argument array used for the CUDA kernel launch.
+//
+// For each operand:
+// - memref operands are unpacked into their descriptor scalar fields so the
+// host-side parameter list matches the NVVM-lowered device kernel signature
+// (gpu-to-nvvm expands each memref into 3 + 2*rank scalar parameters).
+// We delegate to MemRefDescriptor::unpack so we follow the canonical memref
+// descriptor layout owned by the LLVMCommon library;
+// - all other operands are passed through unchanged.
+//
+// The flattened values are materialized on the stack in a single struct
+// (preserving argument order), and a companion pointer array is populated with
+// the address of each field. That pointer array is what the CUDA launch
+// interface expects as kernelParams.
static mlir::Value createKernelArgArray(mlir::Location loc,
- mlir::ValueRange operands,
+ mlir::ValueRange origOperands,
+ mlir::ValueRange adaptedOperands,
mlir::PatternRewriter &rewriter) {
auto *ctx = rewriter.getContext();
- llvm::SmallVector<mlir::Type> structTypes(operands.size(), nullptr);
- for (auto [i, arg] : llvm::enumerate(operands))
- structTypes[i] = arg.getType();
+ llvm::SmallVector<mlir::Value, 8> flatValues;
+ flatValues.reserve(adaptedOperands.size());
+ for (auto [origArg, adaptedArg] :
+ llvm::zip_equal(origOperands, adaptedOperands)) {
+ if (auto memrefTy = mlir::dyn_cast<mlir::MemRefType>(origArg.getType())) {
+ mlir::MemRefDescriptor::unpack(rewriter, loc, adaptedArg, memrefTy,
----------------
razvanlupusoru wrote:
Calling unpack looks great. Just one more thing: can you double-check what useBarePtrCallConv does and whether it is ever enabled in the GPU lowering path? I see that the type converter unpacks the descriptor unless this flag is on: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp#L765
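
To make the concern concrete, here is a minimal standalone sketch in plain C++ (no MLIR dependency) of how the host-side parameter count differs between the two conventions. The names MemRefDesc2D, flattenedParamCount, and packKernelParams are hypothetical and exist only for illustration. The descriptor field order (allocated pointer, aligned pointer, offset, sizes, strides) follows the 3 + 2*rank layout described in the patch comment, and the bare-pointer branch assumes a static-shape memref is passed as its aligned pointer alone. It is not the code in this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side model of a rank-2 memref descriptor:
// 3 fixed fields plus 2 * rank size/stride fields.
struct MemRefDesc2D {
  float *allocated;
  float *aligned;
  std::int64_t offset;
  std::int64_t sizes[2];
  std::int64_t strides[2];
};

// Number of scalar kernel parameters one memref contributes.
// Assumption: with the bare-pointer convention only one pointer is passed.
constexpr std::size_t flattenedParamCount(std::size_t rank, bool barePtr) {
  return barePtr ? 1 : 3 + 2 * rank;
}

// Build a kernelParams-style array: one address per flattened scalar,
// in descriptor field order.
std::vector<void *> packKernelParams(MemRefDesc2D &desc, bool barePtr) {
  std::vector<void *> params;
  if (barePtr) {
    params.push_back(&desc.aligned);
    return params;
  }
  params.push_back(&desc.allocated);
  params.push_back(&desc.aligned);
  params.push_back(&desc.offset);
  for (auto &size : desc.sizes)
    params.push_back(&size);
  for (auto &stride : desc.strides)
    params.push_back(&stride);
  return params;
}
```

If the device kernel were lowered with the bare-pointer convention while the host still packed all seven scalars (or the reverse), the two parameter lists would no longer line up, which is why it is worth confirming which convention the GPU path uses.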
https://github.com/llvm/llvm-project/pull/193651