[flang-commits] [flang] [flang][cuda] Correctly allocate descriptor in managed memory when reboxing (PR #120795)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Fri Dec 20 13:20:18 PST 2024


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/120795

Reboxing might create a new in-memory descriptor. If the original descriptor was allocated in managed memory, allocate the new one in managed memory as well.

>From be03473b49512a807ac9dd97a5b78cf0b66f6507 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 20 Dec 2024 13:18:51 -0800
Subject: [PATCH] [flang][cuda] Correctly allocate descriptor in managed memory
 when reboxing

---
 flang/lib/Optimizer/CodeGen/CodeGen.cpp | 11 ++--
 flang/test/Fir/CUDA/cuda-code-gen.mlir  | 70 +++++++++++++++++++++++++
 2 files changed, 78 insertions(+), 3 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index d09f47a20b33d8..9d911d6bfd4061 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -1725,13 +1725,17 @@ struct EmboxOpConversion : public EmboxCommonConversion<fir::EmboxOp> {
 };
 
 static bool isDeviceAllocation(mlir::Value val) {
+  // Look through loads and conversions to the op that produced the memory.
+  if (auto loadOp = mlir::dyn_cast_or_null<fir::LoadOp>(val.getDefiningOp()))
+    return isDeviceAllocation(loadOp.getMemref());
   if (auto convertOp =
           mlir::dyn_cast_or_null<fir::ConvertOp>(val.getDefiningOp()))
-    val = convertOp.getValue();
+    return isDeviceAllocation(convertOp.getValue());
   if (auto callOp = mlir::dyn_cast_or_null<fir::CallOp>(val.getDefiningOp()))
-    if (callOp.getCallee() &&
-        callOp.getCallee().value().getRootReference().getValue().starts_with(
-            RTNAME_STRING(CUFMemAlloc)))
-      return true;
+    if (auto callee = callOp.getCallee()) {
+      llvm::StringRef name = callee->getRootReference().getValue();
+      return name.starts_with(RTNAME_STRING(CUFMemAlloc)) ||
+             name.starts_with(RTNAME_STRING(CUFAllocDesciptor));
+    }
   return false;
 }
@@ -2045,7 +2049,8 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
     }
     dest = insertBaseAddress(rewriter, loc, dest, base);
     mlir::Value result =
-        placeInMemoryIfNotGlobalInit(rewriter, rebox.getLoc(), destBoxTy, dest);
+        placeInMemoryIfNotGlobalInit(rewriter, rebox.getLoc(), destBoxTy, dest,
+                                     isDeviceAllocation(rebox.getBox()));
     rewriter.replaceOp(rebox, result);
     return mlir::success();
   }
diff --git a/flang/test/Fir/CUDA/cuda-code-gen.mlir b/flang/test/Fir/CUDA/cuda-code-gen.mlir
index a34c2770c5f6c5..47c5667a14c95e 100644
--- a/flang/test/Fir/CUDA/cuda-code-gen.mlir
+++ b/flang/test/Fir/CUDA/cuda-code-gen.mlir
@@ -56,3 +56,73 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<f80 = dense<128> : vector<2xi64>
 // CHECK-LABEL: llvm.func @_QQmain()
 // CHECK: llvm.call @_FortranACUFMemAlloc
 // CHECK: llvm.call @_FortranACUFAllocDesciptor
+
+// -----
+// Rebox of a descriptor loaded from CUFAllocDesciptor-allocated memory (%6).
+module attributes {dlti.dl_spec = #dlti.dl_spec<f80 = dense<128> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, i64 = dense<64> : vector<2xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, !llvm.ptr<270> = dense<32> : vector<4xi64>, f128 = dense<128> : vector<2xi64>, f64 = dense<64> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i1 = dense<8> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, "dlti.endianness" = "little", "dlti.stack_alignment" = 128 : i64>} {
+  func.func @_QQmain() attributes {fir.bindc_name = "p1"} {
+    %c1_i32 = arith.constant 1 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c16_i32 = arith.constant 16 : i32
+    %c1 = arith.constant 1 : index
+    %c0 = arith.constant 0 : index
+    %0 = fir.alloca i32 {bindc_name = "iblk", uniq_name = "_QFEiblk"}
+    %1 = fir.alloca i32 {bindc_name = "ithr", uniq_name = "_QFEithr"}
+    %2 = fir.address_of(@_QQclX64756D6D792E6D6C697200) : !fir.ref<!fir.char<1,11>>
+    %c14_i32 = arith.constant 14 : i32
+    %c72 = arith.constant 72 : index
+    %3 = fir.convert %c72 : (index) -> i64
+    %4 = fir.convert %2 : (!fir.ref<!fir.char<1,11>>) -> !fir.ref<i8>
+    %5 = fir.call @_FortranACUFAllocDesciptor(%3, %4, %c14_i32) : (i64, !fir.ref<i8>, i32) -> !fir.ref<!fir.box<none>>
+    %6 = fir.convert %5 : (!fir.ref<!fir.box<none>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
+    %7 = fir.zero_bits !fir.heap<!fir.array<?x?xf32>>
+    %8 = fircg.ext_embox %7(%c0, %c0) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?x?xf32>>, index, index) -> !fir.box<!fir.heap<!fir.array<?x?xf32>>>
+    fir.store %8 to %6 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
+    %9 = fir.address_of(@_QQclX64756D6D792E6D6C697200) : !fir.ref<!fir.char<1,11>>
+    %c20_i32 = arith.constant 20 : i32
+    %c48 = arith.constant 48 : index
+    %10 = fir.convert %c48 : (index) -> i64
+    %11 = fir.convert %9 : (!fir.ref<!fir.char<1,11>>) -> !fir.ref<i8>
+    %12 = fir.call @_FortranACUFAllocDesciptor(%10, %11, %c20_i32) : (i64, !fir.ref<i8>, i32) -> !fir.ref<!fir.box<none>>
+    %13 = fir.convert %12 : (!fir.ref<!fir.box<none>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+    %14 = fir.zero_bits !fir.heap<!fir.array<?xf32>>
+    %15 = fircg.ext_embox %14(%c0) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?xf32>>, index) -> !fir.box<!fir.heap<!fir.array<?xf32>>>
+    fir.store %15 to %13 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+    %16 = fir.convert %6 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<none>>
+    %17 = fir.convert %c1 : (index) -> i64
+    %18 = fir.convert %c16_i32 : (i32) -> i64
+    %19 = fir.call @_FortranAAllocatableSetBounds(%16, %c0_i32, %17, %18) fastmath<contract> : (!fir.ref<!fir.box<none>>, i32, i64, i64) -> none
+    %20 = fir.call @_FortranAAllocatableSetBounds(%16, %c1_i32, %17, %18) fastmath<contract> : (!fir.ref<!fir.box<none>>, i32, i64, i64) -> none
+    %21 = fir.address_of(@_QQclX64756D6D792E6D6C697200) : !fir.ref<!fir.char<1,11>>
+    %c31_i32 = arith.constant 31 : i32
+    %false = arith.constant false
+    %22 = fir.absent !fir.box<none>
+    %c-1_i64 = arith.constant -1 : i64
+    %23 = fir.convert %6 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<none>>
+    %24 = fir.convert %21 : (!fir.ref<!fir.char<1,11>>) -> !fir.ref<i8>
+    %25 = fir.call @_FortranACUFAllocatableAllocate(%23, %c-1_i64, %false, %22, %24, %c31_i32) : (!fir.ref<!fir.box<none>>, i64, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
+    %26 = fir.convert %13 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+    %27 = fir.call @_FortranAAllocatableSetBounds(%26, %c0_i32, %17, %18) fastmath<contract> : (!fir.ref<!fir.box<none>>, i32, i64, i64) -> none
+    %28 = fir.address_of(@_QQclX64756D6D792E6D6C697200) : !fir.ref<!fir.char<1,11>>
+    %c34_i32 = arith.constant 34 : i32
+    %false_0 = arith.constant false
+    %29 = fir.absent !fir.box<none>
+    %c-1_i64_1 = arith.constant -1 : i64
+    %30 = fir.convert %13 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+    %31 = fir.convert %28 : (!fir.ref<!fir.char<1,11>>) -> !fir.ref<i8>
+    %32 = fir.call @_FortranACUFAllocatableAllocate(%30, %c-1_i64_1, %false_0, %29, %31, %c34_i32) : (!fir.ref<!fir.box<none>>, i64, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
+    %33 = fir.load %6 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
+    %34 = fircg.ext_rebox %33 : (!fir.box<!fir.heap<!fir.array<?x?xf32>>>) -> !fir.box<!fir.array<?x?xf32>>
+    return
+  }
+  func.func private @_FortranAAllocatableSetBounds(!fir.ref<!fir.box<none>>, i32, i64, i64) -> none attributes {fir.runtime}
+  fir.global linkonce @_QQclX64756D6D792E6D6C697200 constant : !fir.char<1,11> {
+    %0 = fir.string_lit "dummy.mlir\00"(11) : !fir.char<1,11>
+    fir.has_value %0 : !fir.char<1,11>
+  }
+  func.func private @_FortranACUFAllocDesciptor(i64, !fir.ref<i8>, i32) -> !fir.ref<!fir.box<none>> attributes {fir.runtime}
+  func.func private @_FortranACUFAllocatableAllocate(!fir.ref<!fir.box<none>>, i64, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32 attributes {fir.runtime}
+}
+// Two descriptor allocations are explicit above; the rest are expected from codegen, including the ext_rebox of %33 (NOTE(review): confirm the source of each).
+// CHECK-LABEL: llvm.func @_QQmain()
+// CHECK-COUNT-4: llvm.call @_FortranACUFAllocDesciptor



More information about the flang-commits mailing list