[flang-commits] [flang] [flang][cuda] Add conversion pass for cuf.allocate (PR #101563)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Thu Aug 1 14:24:34 PDT 2024


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/101563

An allocator can be specified in the descriptor. For simple local allocatables, `cuf.allocate`/`cuf.deallocate` can be converted directly to the corresponding runtime calls in the standard flang runtime. More specific cases will require dedicated entry points, and global descriptors will require synchronization between the host and device copies.

This patch adds a pass to perform this conversion.  

>From 54e314bd7953579641c78a3845a2df10e6841a6a Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 16 Jul 2024 15:49:30 -0700
Subject: [PATCH] [flang][cuda] Add conversion pass for cuf.allocate

---
 .../flang/Optimizer/Transforms/Passes.h       |   1 +
 .../flang/Optimizer/Transforms/Passes.td      |   7 +
 flang/lib/Optimizer/Transforms/CMakeLists.txt |   1 +
 .../Optimizer/Transforms/CufOpConversion.cpp  | 154 ++++++++++++++++++
 flang/test/Fir/CUDA/cuda-allocate.fir         |  21 +++
 5 files changed, 184 insertions(+)
 create mode 100644 flang/lib/Optimizer/Transforms/CufOpConversion.cpp
 create mode 100644 flang/test/Fir/CUDA/cuda-allocate.fir

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index df709645c01b0..96b0e9714b95a 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -39,6 +39,7 @@ namespace fir {
 #define GEN_PASS_DECL_ASSUMEDRANKOPCONVERSION
 #define GEN_PASS_DECL_CHARACTERCONVERSION
 #define GEN_PASS_DECL_CFGCONVERSION
+#define GEN_PASS_DECL_CUFOPCONVERSION
 #define GEN_PASS_DECL_EXTERNALNAMECONVERSION
 #define GEN_PASS_DECL_MEMREFDATAFLOWOPT
 #define GEN_PASS_DECL_SIMPLIFYINTRINSICS
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 786083f95e15c..c703a62c03b7d 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -429,4 +429,11 @@ def AssumedRankOpConversion : Pass<"fir-assumed-rank-op", "mlir::ModuleOp"> {
   ];
 }
 
+def CufOpConversion : Pass<"cuf-convert", "mlir::ModuleOp"> {
+  let summary = "Convert some CUF operations to runtime calls";
+  let dependentDialects = [
+    "fir::FIROpsDialect" // The conversion patterns create FIR operations.
+  ];
+}
+
 #endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 3108304240894..5306b84e0e77a 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_flang_library(FIRTransforms
   CharacterConversion.cpp
   ConstantArgumentGlobalisation.cpp
   ControlFlowConverter.cpp
+  CufOpConversion.cpp
   ArrayValueCopy.cpp
   ExternalNameConversion.cpp
   MemoryUtils.cpp
diff --git a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
new file mode 100644
index 0000000000000..a81f3172bddba
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
@@ -0,0 +1,154 @@
+//===-- CufOpConversion.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Common/Fortran.h"
+#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
+#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Runtime/allocatable.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace fir {
+#define GEN_PASS_DEF_CUFOPCONVERSION
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace fir;
+using namespace mlir;
+using namespace Fortran::runtime;
+
+namespace {
+
+/// Return true if the descriptor used by \p op comes from a fir.declare or
+/// hlfir.declare whose memref is the address of a global (fir.address_of).
+template <typename OpTy>
+static bool isBoxGlobal(OpTy op) {
+  // A block-argument box has no defining op, so use null-tolerant casts.
+  mlir::Operation *boxDef = op.getBox().getDefiningOp();
+  mlir::Value memref;
+  if (auto firDecl = mlir::dyn_cast_or_null<fir::DeclareOp>(boxDef))
+    memref = firDecl.getMemref();
+  else if (auto hlfirDecl = mlir::dyn_cast_or_null<hlfir::DeclareOp>(boxDef))
+    memref = hlfirDecl.getMemref();
+  return memref && mlir::isa_and_nonnull<fir::AddrOfOp>(memref.getDefiningOp());
+}
+
+/// Replace \p op with a call to the runtime function \p func. The call takes
+/// the descriptor, the has-stat flag, the error message box (an absent box
+/// when \p op carries none), and the source file name and line of \p op.
+template <typename OpTy>
+static mlir::LogicalResult convertOpToCall(OpTy op,
+                                           mlir::PatternRewriter &rewriter,
+                                           mlir::func::FuncOp func) {
+  mlir::Location loc = op.getLoc();
+  mlir::FunctionType funcTy = func.getFunctionType();
+  fir::FirOpBuilder builder(rewriter,
+                            op->template getParentOfType<mlir::ModuleOp>());
+
+  // Source position of the op; the line number uses the type of input 4.
+  mlir::Value fileName = fir::factory::locationToFilename(builder, loc);
+  mlir::Value lineNo =
+      fir::factory::locationToLineNo(builder, loc, funcTy.getInput(4));
+  mlir::Value statFlag = builder.createBool(loc, op.getHasStat());
+
+  // Without an error message operand, pass an absent box instead.
+  mlir::Value errorBox = op.getErrmsg();
+  if (!errorBox)
+    errorBox = builder.create<fir::AbsentOp>(
+        loc, fir::BoxType::get(builder.getNoneType())).getResult();
+  llvm::SmallVector<mlir::Value> callArgs = fir::runtime::createArguments(
+      builder, loc, funcTy, op.getBox(), statFlag, errorBox, fileName, lineNo);
+  auto runtimeCall = builder.create<fir::CallOp>(loc, func, callArgs);
+  rewriter.replaceOp(op, runtimeCall);
+  return mlir::success();
+}
+
+/// Convert cuf.allocate on a local descriptor to a call to the standard
+/// runtime AllocatableAllocate. The pattern does not match allocations with
+/// a source, stream or pinned specification, nor allocations of global
+/// descriptors.
+struct CufAllocateOpConversion
+    : public mlir::OpRewritePattern<cuf::AllocateOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(cuf::AllocateOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    // TODO: Forms that are not supported yet:
+    //  - source: allocation with source needs a new runtime entry point.
+    //  - stream: allocation using a different stream.
+    //  - pinned: reference to a logical value that can be set to true when
+    //    the pinned allocation succeeds; this requires a new entry point.
+    if (op.getSource() || op.getStream() || op.getPinned())
+      return mlir::failure();
+
+    // TODO: Allocating a module variable needs more work as the descriptor
+    // will be duplicated and needs to be synced after allocation.
+    const bool isModuleVariable = isBoxGlobal(op);
+    if (isModuleVariable)
+      return mlir::failure();
+
+    // A local descriptor has its dedicated allocator set before the call,
+    // so the standard runtime AllocatableAllocate can be used as is.
+    mlir::Location loc = op.getLoc();
+    mlir::ModuleOp mod = op->getParentOfType<mlir::ModuleOp>();
+    fir::FirOpBuilder builder(rewriter, mod);
+
+    mlir::func::FuncOp allocateFunc =
+        fir::runtime::getRuntimeFunc<mkRTKey(AllocatableAllocate)>(
+            loc, builder);
+    return convertOpToCall<cuf::AllocateOp>(op, rewriter, allocateFunc);
+  }
+};
+
+/// Convert cuf.deallocate on a local descriptor to a call to the standard
+/// runtime AllocatableDeallocate. Global descriptors are not matched.
+struct CufDeallocateOpConversion
+    : public mlir::OpRewritePattern<cuf::DeallocateOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(cuf::DeallocateOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    // TODO: Module variables need more work as their descriptor will be
+    // duplicated and the copies need to be synced.
+    if (isBoxGlobal(op))
+      return mlir::failure();
+
+    // A local descriptor has its dedicated deallocator set before the call,
+    // so the standard runtime AllocatableDeallocate can be used as is.
+    mlir::Location loc = op.getLoc();
+    fir::FirOpBuilder builder(rewriter, op->getParentOfType<mlir::ModuleOp>());
+    mlir::func::FuncOp deallocateFunc =
+        fir::runtime::getRuntimeFunc<mkRTKey(AllocatableDeallocate)>(
+            loc, builder);
+    return convertOpToCall<cuf::DeallocateOp>(op, rewriter, deallocateFunc);
+  }
+};
+
+/// Convert cuf.allocate/cuf.deallocate to runtime calls, or fail the pass.
+class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
+public:
+  void runOnOperation() override {
+    mlir::MLIRContext *ctx = &getContext();
+    mlir::ConversionTarget target(*ctx);
+    target.addIllegalOp<cuf::AllocateOp, cuf::DeallocateOp>();
+    mlir::RewritePatternSet patterns(ctx);
+    patterns.insert<CufAllocateOpConversion, CufDeallocateOpConversion>(ctx);
+    if (mlir::succeeded(mlir::applyPartialConversion(getOperation(), target,
+                                                     std::move(patterns))))
+      return;
+    mlir::emitError(mlir::UnknownLoc::get(ctx), "error in CUF op conversion\n");
+    signalPassFailure();
+  }
+};
+} // namespace
diff --git a/flang/test/Fir/CUDA/cuda-allocate.fir b/flang/test/Fir/CUDA/cuda-allocate.fir
new file mode 100644
index 0000000000000..ab4a253f33dd8
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-allocate.fir
@@ -0,0 +1,21 @@
+// RUN: fir-opt --cuf-convert %s | FileCheck %s
+
+func.func @_QPsub1() {
+  %0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Ea"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+  %4:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %c1 = arith.constant 1 : index
+  %c10_i32 = arith.constant 10 : i32
+  %c0_i32 = arith.constant 0 : i32
+  %9 = cuf.allocate %4#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>} -> i32
+  %10 = cuf.deallocate %4#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>} -> i32
+  return
+}
+
+// CHECK-LABEL: func.func @_QPsub1()
+// CHECK: %[[DESC:.*]] = cuf.alloc !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Ea"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+// CHECK: %[[DECL_DESC:.*]]:2 = hlfir.declare %[[DESC]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DECL_DESC]]#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: %{{.*}} = fir.call @_FortranAAllocatableAllocate(%[[BOX_NONE]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
+
+// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DECL_DESC]]#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: %{{.*}} = fir.call @_FortranAAllocatableDeallocate(%[[BOX_NONE]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32



More information about the flang-commits mailing list