[flang-commits] [flang] [flang][cuda] Include constant extents in cuf.alloc size computation (PR #194288)

via flang-commits flang-commits at lists.llvm.org
Sun Apr 26 22:20:48 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: khaki3

<details>
<summary>Changes</summary>

Example:
```fortran
real, device :: arr(n, 4, m)
```

For arrays with mixed constant and dynamic extents (e.g. `!fir.array<?x4x?xf32>`), the computed allocation size includes only the dynamic extents multiplied by the element size (`n * m * 4` bytes) and omits the constant extent `4`. The buffer is therefore under-allocated by a factor of 4, which causes `cudaErrorIllegalAddress` at run time.

Fix: after multiplying in the dynamic extents, also multiply in the constant extents taken from the array type. This produces the correct size (`n * m * 4 * 4` bytes: the extents `n`, `m`, and `4`, times the 4-byte element size).

---
Full diff: https://github.com/llvm/llvm-project/pull/194288.diff


2 Files Affected:

- (modified) flang/lib/Optimizer/Transforms/CUDA/CUFAllocationConversion.cpp (+8) 
- (modified) flang/test/Fir/CUDA/cuda-alloc-free.fir (+28) 


``````````diff
diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFAllocationConversion.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFAllocationConversion.cpp
index 1b11ba99f8d4b..a3a6e575b0b98 100644
--- a/flang/lib/Optimizer/Transforms/CUDA/CUFAllocationConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFAllocationConversion.cpp
@@ -206,6 +206,14 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
                 rewriter, loc, nbElem,
                 builder.loadIfRef(loc, op.getShape()[i]));
           }
+          for (auto extent : seqTy.getShape()) {
+            if (extent != fir::SequenceType::getUnknownExtent()) {
+              nbElem = mlir::arith::MulIOp::create(
+                  rewriter, loc, nbElem,
+                  builder.createIntegerConstant(loc, builder.getIndexType(),
+                                                extent));
+            }
+          }
         } else {
           nbElem = builder.createIntegerConstant(loc, builder.getIndexType(),
                                                  seqTy.getConstantArraySize());
diff --git a/flang/test/Fir/CUDA/cuda-alloc-free.fir b/flang/test/Fir/CUDA/cuda-alloc-free.fir
index 4872f80071368..f865cd3d7c961 100644
--- a/flang/test/Fir/CUDA/cuda-alloc-free.fir
+++ b/flang/test/Fir/CUDA/cuda-alloc-free.fir
@@ -114,4 +114,32 @@ func.func @_QQalloc_char2() {
 // CHECK: %[[BYTES_40:.*]] = fir.convert %c40{{.*}} : (index) -> i64
 // CHECK: %{{.*}} = fir.call @_FortranACUFMemAlloc(%[[BYTES_40]], %{{.*}}, %{{.*}}, %{{.*}}) {cuf.data_attr = #cuf.cuda<device>} : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
 
+// Test that cuf.alloc with mixed constant and dynamic extents includes
+// the constant extents in the allocation size. For dimension(n, 4, m),
+// size must be n * m * 4(constant) * 4(bytes), not n * m * 4(bytes).
+func.func @_QPsub_mixed(%arg0: !fir.ref<i32>, %arg1: !fir.ref<i32>) {
+  %0 = fir.load %arg0 : !fir.ref<i32>
+  %1 = fir.convert %0 : (i32) -> index
+  %c0 = arith.constant 0 : index
+  %2 = arith.cmpi sgt, %1, %c0 : index
+  %n = arith.select %2, %1, %c0 : index
+  %3 = fir.load %arg1 : !fir.ref<i32>
+  %4 = fir.convert %3 : (i32) -> index
+  %5 = arith.cmpi sgt, %4, %c0 : index
+  %m = arith.select %5, %4, %c0 : index
+  %6 = cuf.alloc !fir.array<?x4x?xf32>, %n, %m : index, index {bindc_name = "arr", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub_mixedEarr"} -> !fir.ref<!fir.array<?x4x?xf32>>
+  cuf.free %6 : !fir.ref<!fir.array<?x4x?xf32>> {data_attr = #cuf.cuda<device>}
+  return
+}
+
+// CHECK-LABEL: func.func @_QPsub_mixed
+// CHECK: %[[N:.*]] = arith.select
+// CHECK: %[[M:.*]] = arith.select
+// CHECK: %[[NM:.*]] = arith.muli %[[N]], %[[M]] : index
+// CHECK: %[[NM4:.*]] = arith.muli %[[NM]], %c4{{.*}} : index
+// CHECK: %[[BYTES:.*]] = arith.muli %[[NM4]], %c4{{.*}} : index
+// CHECK: %[[CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
+// CHECK: fir.call @_FortranACUFMemAlloc(%[[CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) {cuf.data_attr = #cuf.cuda<device>}
+// CHECK: fir.call @_FortranACUFMemFree
+
 } // end module

``````````

</details>


https://github.com/llvm/llvm-project/pull/194288


More information about the flang-commits mailing list