[llvm-branch-commits] [flang] [flang][cuda] Convert module allocation/deallocation to runtime calls (PR #109214)

Valentin Clement バレンタイン クレメン via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Sep 18 15:46:33 PDT 2024


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/109214

>From be4731f339d6fd9b45cd7cc93e3dd8ff83e80576 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 18 Sep 2024 15:42:19 -0700
Subject: [PATCH] [flang][cuda] Convert module allocation/deallocation to
 runtime calls

---
 .../Optimizer/Transforms/CufOpConversion.cpp  | 59 +++++++++++--------
 flang/test/Fir/CUDA/cuda-allocate.fir         | 40 ++++++++++++-
 2 files changed, 74 insertions(+), 25 deletions(-)

diff --git a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
index 2dc37f4df3aeec..ac796e83b07078 100644
--- a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
@@ -14,6 +14,7 @@
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/Support/DataLayout.h"
+#include "flang/Runtime/CUDA/allocatable.h"
 #include "flang/Runtime/CUDA/common.h"
 #include "flang/Runtime/CUDA/descriptor.h"
 #include "flang/Runtime/CUDA/memory.h"
@@ -35,13 +36,19 @@ using namespace Fortran::runtime::cuda;
 namespace {
 
 template <typename OpTy>
-static bool needDoubleDescriptor(OpTy op) {
+static bool isPinned(OpTy op) {
+  if (op.getDataAttr() && *op.getDataAttr() == cuf::DataAttribute::Pinned)
+    return true;
+  return false;
+}
+
+template <typename OpTy>
+static bool hasDoubleDescriptors(OpTy op) {
   if (auto declareOp =
           mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp())) {
     if (mlir::isa_and_nonnull<fir::AddrOfOp>(
             declareOp.getMemref().getDefiningOp())) {
-      if (declareOp.getDataAttr() &&
-          *declareOp.getDataAttr() == cuf::DataAttribute::Pinned)
+      if (isPinned(declareOp))
         return false;
       return true;
     }
@@ -49,8 +56,7 @@ static bool needDoubleDescriptor(OpTy op) {
                  op.getBox().getDefiningOp())) {
     if (mlir::isa_and_nonnull<fir::AddrOfOp>(
             declareOp.getMemref().getDefiningOp())) {
-      if (declareOp.getDataAttr() &&
-          *declareOp.getDataAttr() == cuf::DataAttribute::Pinned)
+      if (isPinned(declareOp))
         return false;
       return true;
     }
@@ -108,17 +114,22 @@ struct CufAllocateOpConversion
     if (op.getPinned())
       return mlir::failure();
 
-    // TODO: Allocation of module variable will need more work as the descriptor
-    // will be duplicated and needs to be synced after allocation.
-    if (needDoubleDescriptor(op))
-      return mlir::failure();
+    auto mod = op->getParentOfType<mlir::ModuleOp>();
+    fir::FirOpBuilder builder(rewriter, mod);
+    mlir::Location loc = op.getLoc();
+
+    if (hasDoubleDescriptors(op)) {
+      // Allocation of module variables is done with a custom runtime entry
+      // point so the descriptors can be synchronized.
+      mlir::func::FuncOp func =
+          fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocate)>(
+              loc, builder);
+      return convertOpToCall(op, rewriter, func);
+    }
 
     // Allocation for local descriptor falls back on the standard runtime
     // AllocatableAllocate as the dedicated allocator is set in the descriptor
     // before the call.
-    auto mod = op->template getParentOfType<mlir::ModuleOp>();
-    fir::FirOpBuilder builder(rewriter, mod);
-    mlir::Location loc = op.getLoc();
     mlir::func::FuncOp func =
         fir::runtime::getRuntimeFunc<mkRTKey(AllocatableAllocate)>(loc,
                                                                    builder);
@@ -133,17 +144,23 @@ struct CufDeallocateOpConversion
   mlir::LogicalResult
   matchAndRewrite(cuf::DeallocateOp op,
                   mlir::PatternRewriter &rewriter) const override {
-    // TODO: Allocation of module variable will need more work as the descriptor
-    // will be duplicated and needs to be synced after allocation.
-    if (needDoubleDescriptor(op))
-      return mlir::failure();
 
-    // Deallocation for local descriptor falls back on the standard runtime
-    // AllocatableDeallocate as the dedicated deallocator is set in the
-    // descriptor before the call.
     auto mod = op->getParentOfType<mlir::ModuleOp>();
     fir::FirOpBuilder builder(rewriter, mod);
     mlir::Location loc = op.getLoc();
+
+    if (hasDoubleDescriptors(op)) {
+      // Deallocation of module variables is done with a custom runtime entry
+      // point so the descriptors can be synchronized.
+      mlir::func::FuncOp func =
+          fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableDeallocate)>(
+              loc, builder);
+      return convertOpToCall(op, rewriter, func);
+    }
+
+    // Deallocation for local descriptor falls back on the standard runtime
+    // AllocatableDeallocate as the dedicated deallocator is set in the
+    // descriptor before the call.
     mlir::func::FuncOp func =
         fir::runtime::getRuntimeFunc<mkRTKey(AllocatableDeallocate)>(loc,
                                                                      builder);
@@ -448,10 +465,6 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
       }
       return true;
     });
-    target.addDynamicallyLegalOp<cuf::AllocateOp>(
-        [](::cuf::AllocateOp op) { return needDoubleDescriptor(op); });
-    target.addDynamicallyLegalOp<cuf::DeallocateOp>(
-        [](::cuf::DeallocateOp op) { return needDoubleDescriptor(op); });
     target.addDynamicallyLegalOp<cuf::DataTransferOp>(
         [](::cuf::DataTransferOp op) {
           mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
diff --git a/flang/test/Fir/CUDA/cuda-allocate.fir b/flang/test/Fir/CUDA/cuda-allocate.fir
index 1c17e7447e5c97..65c68bb69301af 100644
--- a/flang/test/Fir/CUDA/cuda-allocate.fir
+++ b/flang/test/Fir/CUDA/cuda-allocate.fir
@@ -54,8 +54,14 @@ func.func @_QPsub3() {
 }
 
 // CHECK-LABEL: func.func @_QPsub3()
-// CHECK: cuf.allocate
-// CHECK: cuf.deallocate
+// CHECK: %[[A_ADDR:.*]] = fir.address_of(@_QMmod1Ea) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+// CHECK: %[[A:.*]]:2 = hlfir.declare %[[A_ADDR]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+
+// CHECK: %[[A_BOX:.*]] = fir.convert %[[A]]#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFAllocatableAllocate(%[[A_BOX]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
+
+// CHECK: %[[A_BOX:.*]] = fir.convert %[[A]]#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFAllocatableDeallocate(%[[A_BOX]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
 
 func.func @_QPsub4() attributes {cuf.proc_attr = #cuf.cuda_proc<device>} {
   %0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Ea"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
@@ -95,4 +101,34 @@ func.func @_QPsub5() {
 // CHECK: fir.call @_FortranAAllocatableAllocate({{.*}}) : (!fir.ref<!fir.box<none>>, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
 // CHECK: fir.call @_FortranAAllocatableDeallocate({{.*}}) : (!fir.ref<!fir.box<none>>, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
 
+
+fir.global @_QMdataEb {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?xi32>>> {
+  %c0 = arith.constant 0 : index
+  %0 = fir.zero_bits !fir.heap<!fir.array<?xi32>>
+  %1 = fir.shape %c0 : (index) -> !fir.shape<1>
+  %2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
+  fir.has_value %2 : !fir.box<!fir.heap<!fir.array<?xi32>>>
+}
+
+func.func @_QQsub6() attributes {fir.bindc_name = "test"} {
+  %c0_i32 = arith.constant 0 : i32
+  %c10_i32 = arith.constant 10 : i32
+  %c1 = arith.constant 1 : index
+  %0 = fir.address_of(@_QMdataEb) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  %1:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMdataEb"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+  %2 = fir.convert %1#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
+  %3 = fir.convert %c1 : (index) -> i64
+  %4 = fir.convert %c10_i32 : (i32) -> i64
+  %5 = fir.call @_FortranAAllocatableSetBounds(%2, %c0_i32, %3, %4) fastmath<contract> : (!fir.ref<!fir.box<none>>, i32, i64, i64) -> none
+  %6 = cuf.allocate %1#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {data_attr = #cuf.cuda<device>} -> i32
+  return
+}
+
+// CHECK-LABEL: func.func @_QQsub6() attributes {fir.bindc_name = "test"}
+// CHECK: %[[B_ADDR:.*]] = fir.address_of(@_QMdataEb) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+// CHECK: %[[B:.*]]:2 = hlfir.declare %[[B_ADDR]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMdataEb"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+// CHECK: _FortranAAllocatableSetBounds
+// CHECK: %[[B_BOX:.*]] = fir.convert %[[B]]#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFAllocatableAllocate(%[[B_BOX]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
+
 } // end of module



More information about the llvm-branch-commits mailing list