[flang-commits] [flang] [flang][cuda] Flatten memref descriptors in GPU kernel argument packing (PR #193651)
Zhen Wang via flang-commits
flang-commits at lists.llvm.org
Thu Apr 23 09:09:28 PDT 2026
================
@@ -33,16 +34,41 @@ using namespace Fortran::runtime;
namespace {
+// Build the kernel argument array used for the CUDA kernel launch.
+//
+// For each operand:
+// - memref operands are unpacked into their descriptor scalar fields so the
+// host-side parameter list matches the NVVM-lowered device kernel signature
+// (gpu-to-nvvm expands each memref into 3 + 2*rank scalar parameters).
+// We delegate to MemRefDescriptor::unpack so we follow the canonical memref
+// descriptor layout owned by the LLVMCommon library;
+// - all other operands are passed through unchanged.
+//
+// The flattened values are materialized on the stack in a single struct
+// (preserving argument order), and a companion pointer array is populated with
+// the address of each field. That pointer array is what the CUDA launch
+// interface expects as kernelParams.
static mlir::Value createKernelArgArray(mlir::Location loc,
- mlir::ValueRange operands,
+ mlir::ValueRange origOperands,
+ mlir::ValueRange adaptedOperands,
mlir::PatternRewriter &rewriter) {
auto *ctx = rewriter.getContext();
- llvm::SmallVector<mlir::Type> structTypes(operands.size(), nullptr);
- for (auto [i, arg] : llvm::enumerate(operands))
- structTypes[i] = arg.getType();
+ llvm::SmallVector<mlir::Value, 8> flatValues;
+ flatValues.reserve(adaptedOperands.size());
+ for (auto [origArg, adaptedArg] :
+ llvm::zip_equal(origOperands, adaptedOperands)) {
+ if (auto memrefTy = mlir::dyn_cast<mlir::MemRefType>(origArg.getType())) {
+ mlir::MemRefDescriptor::unpack(rewriter, loc, adaptedArg, memrefTy,
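The packing described in the comment on createKernelArgArray can be modeled in plain C++. The sketch below is not the flang implementation, which emits LLVM dialect IR. It assumes a kernel taking `(memref<4x8xf32>, i32)` under the default (non-bare-pointer) lowering. `FlatArgs`, `flattenedMemrefParams`, and `buildKernelParams` are hypothetical names. The memref contributes 3 + 2*rank = 7 scalar fields (allocated pointer, aligned pointer, offset, sizes, strides), and the scalar is passed through unchanged. A companion array then holds one pointer per field, in argument order.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side stack struct for a kernel taking
// (memref<4x8xf32>, i32): the memref is flattened into its descriptor
// fields (allocated ptr, aligned ptr, offset, sizes[2], strides[2]),
// followed by the non-memref operand.
struct FlatArgs {
  float *allocated;
  float *aligned;
  int64_t offset;
  int64_t sizes[2];
  int64_t strides[2];
  int32_t n; // non-memref operand, passed through unchanged
};

// Number of scalar parameters a ranked memref expands into.
constexpr int64_t flattenedMemrefParams(int64_t rank) { return 3 + 2 * rank; }

// Companion pointer array: one pointer to each flattened field, in
// argument order. This mirrors the kernelParams layout of a CUDA launch,
// where each parameter is read through its pointer.
std::vector<void *> buildKernelParams(FlatArgs &args) {
  std::vector<void *> params = {&args.allocated, &args.aligned, &args.offset};
  for (int64_t &s : args.sizes)
    params.push_back(&s);
  for (int64_t &s : args.strides)
    params.push_back(&s);
  params.push_back(&args.n);
  // 7 descriptor fields for the rank-2 memref plus 1 scalar.
  assert(params.size() ==
         static_cast<std::size_t>(flattenedMemrefParams(2) + 1));
  return params;
}
```

On a real launch, a pointer such as `params.data()` is what would be handed to the CUDA launch call as kernelParams. For this to work, the struct must stay alive until the launch has consumed the parameters.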
----------------
wangzpgi wrote:
Good catch. Currently useBarePtrCallConv is off in the CUDA Fortran pipeline, so MemRefDescriptor::unpack happens to match. Rather than depend on that, I will switch to LLVMTypeConverter::promoteOperands: it branches internally on useBarePtrCallConv and also handles UnrankedMemRefType, so host-side packing stays consistent with however the type converter is configured.
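The reply's argument is that the number of host-side parameters per operand depends on the calling convention, so packing should follow the type converter's configuration. The sketch below models that per-operand choice in plain C++. It does not call MLIR. `OperandKind`, `Operand`, `paramCount`, and `totalParams` are hypothetical names, and the rules encoded are assumptions about the LLVM lowering:

- Under the bare-pointer convention, a ranked memref is passed as just its aligned pointer, and unranked memrefs are treated as unsupported.
- Under the default convention, a ranked memref expands to 3 + 2*rank scalars and an unranked memref to a (rank, descriptor pointer) pair.

The sketch also ignores the extra eligibility restrictions (such as static shapes) that the real converter places on bare-pointer memrefs.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical, simplified operand classification; the real code inspects
// the original operand's MLIR type (MemRefType / UnrankedMemRefType).
enum class OperandKind { Scalar, RankedMemRef, UnrankedMemRef };

struct Operand {
  OperandKind kind;
  int64_t rank; // meaningful only for RankedMemRef
};

// Host-side parameters contributed by one operand (assumed rules):
//  - bare-pointer convention: ranked memref -> aligned pointer only;
//    unranked memref -> unsupported (std::nullopt);
//  - default convention: ranked memref -> 3 + 2*rank scalars;
//    unranked memref -> (rank, descriptor pointer);
//  - any other operand -> passed through as one parameter.
std::optional<int64_t> paramCount(const Operand &op, bool useBarePtrCallConv) {
  switch (op.kind) {
  case OperandKind::Scalar:
    return 1;
  case OperandKind::RankedMemRef:
    return useBarePtrCallConv ? 1 : 3 + 2 * op.rank;
  case OperandKind::UnrankedMemRef:
    if (useBarePtrCallConv)
      return std::nullopt;
    return 2;
  }
  return std::nullopt;
}

// Total parameter count for a kernel's operand list, or std::nullopt if any
// operand cannot be lowered under the chosen convention.
std::optional<int64_t> totalParams(const std::vector<Operand> &ops,
                                   bool useBarePtrCallConv) {
  int64_t total = 0;
  for (const Operand &op : ops) {
    std::optional<int64_t> n = paramCount(op, useBarePtrCallConv);
    if (!n)
      return std::nullopt;
    total += *n;
  }
  return total;
}
```

For a `(memref<?x?xf32>, i32)` kernel, this model gives 8 parameters under the default convention but 2 under the bare-pointer convention. That difference is why hard-coding MemRefDescriptor::unpack on the host only works while the flag stays off.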
https://github.com/llvm/llvm-project/pull/193651