[flang-commits] [flang] [flang][cuda] Add conversion pass for cuf.allocate (PR #101563)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Thu Aug 1 14:24:34 PDT 2024
https://github.com/clementval created https://github.com/llvm/llvm-project/pull/101563
The allocator can be specified in the descriptor. For a simple local allocatable, we can simply convert `cuf.allocate`/`cuf.deallocate` to their corresponding runtime calls in the standard flang runtime. More specific cases will require dedicated entry points. Global descriptors will require synchronization between the host and device copies.
This patch adds a pass to perform this conversion.
>From 54e314bd7953579641c78a3845a2df10e6841a6a Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 16 Jul 2024 15:49:30 -0700
Subject: [PATCH] [flang][cuda] Add conversion pass for cuf.allocate
---
.../flang/Optimizer/Transforms/Passes.h | 1 +
.../flang/Optimizer/Transforms/Passes.td | 7 +
flang/lib/Optimizer/Transforms/CMakeLists.txt | 1 +
.../Optimizer/Transforms/CufOpConversion.cpp | 154 ++++++++++++++++++
flang/test/Fir/CUDA/cuda-allocate.fir | 21 +++
5 files changed, 184 insertions(+)
create mode 100644 flang/lib/Optimizer/Transforms/CufOpConversion.cpp
create mode 100644 flang/test/Fir/CUDA/cuda-allocate.fir
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index df709645c01b0..96b0e9714b95a 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -39,6 +39,7 @@ namespace fir {
#define GEN_PASS_DECL_ASSUMEDRANKOPCONVERSION
#define GEN_PASS_DECL_CHARACTERCONVERSION
#define GEN_PASS_DECL_CFGCONVERSION
+#define GEN_PASS_DECL_CUFOPCONVERSION
#define GEN_PASS_DECL_EXTERNALNAMECONVERSION
#define GEN_PASS_DECL_MEMREFDATAFLOWOPT
#define GEN_PASS_DECL_SIMPLIFYINTRINSICS
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 786083f95e15c..c703a62c03b7d 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -429,4 +429,11 @@ def AssumedRankOpConversion : Pass<"fir-assumed-rank-op", "mlir::ModuleOp"> {
];
}
+def CufOpConversion : Pass<"cuf-convert", "mlir::ModuleOp"> {
+  let summary = "Convert some CUF operations to runtime calls";
+  let description = [{
+    Convert `cuf.allocate` and `cuf.deallocate` operations on local
+    descriptors into calls to the standard Flang allocatable runtime entry
+    points `AllocatableAllocate` and `AllocatableDeallocate`. The dedicated
+    allocator is expected to be set in the descriptor before the call.
+    Allocations with a source, a stream or a pinned argument, and operations
+    on descriptors declared from a global (e.g. module variables), are not
+    supported yet; the pass reports an error when it encounters them.
+  }];
+  let dependentDialects = [
+    "fir::FIROpsDialect"
+  ];
+}
+
#endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 3108304240894..5306b84e0e77a 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_flang_library(FIRTransforms
CharacterConversion.cpp
ConstantArgumentGlobalisation.cpp
ControlFlowConverter.cpp
+ CufOpConversion.cpp
ArrayValueCopy.cpp
ExternalNameConversion.cpp
MemoryUtils.cpp
diff --git a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
new file mode 100644
index 0000000000000..a81f3172bddba
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
@@ -0,0 +1,154 @@
+//===-- CufOpConversion.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Common/Fortran.h"
+#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
+#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Runtime/allocatable.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace fir {
+#define GEN_PASS_DEF_CUFOPCONVERSION
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace fir;
+using namespace mlir;
+using namespace Fortran::runtime;
+
+namespace {
+
+/// Return true when the descriptor operated on by \p op is declared (through
+/// `fir.declare` or `hlfir.declare`) from storage produced by
+/// `fir.address_of`, i.e. a descriptor living in a global such as a module
+/// variable.
+///
+/// Values without a defining operation (block arguments) are never considered
+/// global.
+template <typename OpTy>
+static bool isBoxGlobal(OpTy op) {
+  mlir::Operation *boxDefOp = op.getBox().getDefiningOp();
+
+  // Look through the FIR or HLFIR declaration to the declared storage.
+  // dyn_cast_or_null is required: the box has no defining operation when it
+  // is a block argument (e.g. a dummy argument).
+  mlir::Value memref;
+  if (auto firDeclare = mlir::dyn_cast_or_null<fir::DeclareOp>(boxDefOp))
+    memref = firDeclare.getMemref();
+  else if (auto hlfirDeclare =
+               mlir::dyn_cast_or_null<hlfir::DeclareOp>(boxDefOp))
+    memref = hlfirDeclare.getMemref();
+  else
+    return false;
+
+  // The declared storage may itself be a block argument, hence
+  // isa_and_nonnull rather than isa (which asserts on null).
+  return mlir::isa_and_nonnull<fir::AddrOfOp>(memref.getDefiningOp());
+}
+
+/// Replace \p op with a `fir.call` to the runtime function \p func.
+///
+/// The call arguments are, in order:
+///   (box, hasStat, errmsg, sourceFile, sourceLine)
+/// and are converted to the runtime signature by
+/// fir::runtime::createArguments. When \p op has no errmsg operand, an absent
+/// `!fir.box<none>` is passed instead. The call result replaces the result of
+/// \p op.
+template <typename OpTy>
+static mlir::LogicalResult convertOpToCall(OpTy op,
+                                           mlir::PatternRewriter &rewriter,
+                                           mlir::func::FuncOp func) {
+  fir::FirOpBuilder builder(
+      rewriter, op->template getParentOfType<mlir::ModuleOp>());
+  mlir::Location loc = op.getLoc();
+  mlir::FunctionType funcTy = func.getFunctionType();
+
+  // Source location arguments. The line number is created with the type of
+  // the runtime function's fifth parameter (index 4, sourceLine).
+  mlir::Value fileName = fir::factory::locationToFilename(builder, loc);
+  mlir::Value lineNo =
+      fir::factory::locationToLineNo(builder, loc, funcTy.getInput(4));
+
+  mlir::Value statFlag = builder.createBool(loc, op.getHasStat());
+
+  // Fall back to an absent descriptor when no errmsg variable is provided.
+  mlir::Value errmsgBox = op.getErrmsg();
+  if (!errmsgBox) {
+    mlir::Type noneBoxTy = fir::BoxType::get(builder.getNoneType());
+    errmsgBox = builder.create<fir::AbsentOp>(loc, noneBoxTy).getResult();
+  }
+
+  llvm::SmallVector<mlir::Value> callArgs{fir::runtime::createArguments(
+      builder, loc, funcTy, op.getBox(), statFlag, errmsgBox, fileName, lineNo)};
+  fir::CallOp runtimeCall = builder.create<fir::CallOp>(loc, func, callArgs);
+  rewriter.replaceOp(op, runtimeCall);
+  return mlir::success();
+}
+
+/// Rewrite `cuf.allocate` on a local descriptor into a call to the standard
+/// runtime entry point AllocatableAllocate.
+///
+/// Unsupported forms (source, stream, pinned, or a descriptor declared from a
+/// global) are rejected with a match-failure reason, so the cause is visible
+/// when debugging the conversion.
+struct CufAllocateOpConversion
+    : public mlir::OpRewritePattern<cuf::AllocateOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(cuf::AllocateOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    // TODO: Allocation with source will need a new entry point in the runtime.
+    if (op.getSource())
+      return rewriter.notifyMatchFailure(
+          op, "cuf.allocate with source is not supported yet");
+
+    // TODO: Allocation using a different stream.
+    if (op.getStream())
+      return rewriter.notifyMatchFailure(
+          op, "cuf.allocate with stream is not supported yet");
+
+    // TODO: Pinned is a reference to a logical value that can be set to true
+    // when the pinned allocation succeeds. This will require a new entry
+    // point.
+    if (op.getPinned())
+      return rewriter.notifyMatchFailure(
+          op, "cuf.allocate with pinned is not supported yet");
+
+    // TODO: Allocation of a module variable will need more work as the
+    // descriptor will be duplicated and needs to be synced after allocation.
+    if (isBoxGlobal(op))
+      return rewriter.notifyMatchFailure(
+          op, "cuf.allocate of a global descriptor is not supported yet");
+
+    // Allocation for a local descriptor falls back on the standard runtime
+    // AllocatableAllocate as the dedicated allocator is set in the descriptor
+    // before the call.
+    auto mod = op->getParentOfType<mlir::ModuleOp>();
+    fir::FirOpBuilder builder(rewriter, mod);
+    mlir::Location loc = op.getLoc();
+    mlir::func::FuncOp func =
+        fir::runtime::getRuntimeFunc<mkRTKey(AllocatableAllocate)>(loc,
+                                                                   builder);
+    return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
+  }
+};
+
+/// Rewrite `cuf.deallocate` on a local descriptor into a call to the standard
+/// runtime entry point AllocatableDeallocate.
+///
+/// Descriptors declared from a global are rejected with a match-failure
+/// reason, so the cause is visible when debugging the conversion.
+struct CufDeallocateOpConversion
+    : public mlir::OpRewritePattern<cuf::DeallocateOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(cuf::DeallocateOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    // TODO: Deallocation of a module variable will need more work as the
+    // descriptor will be duplicated and needs to be synced after
+    // deallocation.
+    if (isBoxGlobal(op))
+      return rewriter.notifyMatchFailure(
+          op, "cuf.deallocate of a global descriptor is not supported yet");
+
+    // Deallocation for a local descriptor falls back on the standard runtime
+    // AllocatableDeallocate as the dedicated deallocator is set in the
+    // descriptor before the call.
+    auto mod = op->getParentOfType<mlir::ModuleOp>();
+    fir::FirOpBuilder builder(rewriter, mod);
+    mlir::Location loc = op.getLoc();
+    mlir::func::FuncOp func =
+        fir::runtime::getRuntimeFunc<mkRTKey(AllocatableDeallocate)>(loc,
+                                                                     builder);
+    return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
+  }
+};
+
+/// Module pass converting `cuf.allocate` / `cuf.deallocate` into runtime
+/// calls.
+///
+/// Both operations are marked illegal, so any occurrence the patterns decline
+/// to rewrite makes the partial conversion fail; the pass then emits an error
+/// and signals failure.
+class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
+public:
+  void runOnOperation() override {
+    mlir::MLIRContext *context = &getContext();
+
+    // Under a partial conversion, operations not explicitly marked remain
+    // legal; only the two CUF (de)allocation operations must be rewritten.
+    mlir::ConversionTarget target(*context);
+    target.addIllegalOp<cuf::AllocateOp, cuf::DeallocateOp>();
+
+    mlir::RewritePatternSet patterns(context);
+    patterns.insert<CufAllocateOpConversion, CufDeallocateOpConversion>(context);
+
+    mlir::LogicalResult converted = mlir::applyPartialConversion(
+        getOperation(), target, std::move(patterns));
+    if (mlir::succeeded(converted))
+      return;
+
+    mlir::emitError(mlir::UnknownLoc::get(context),
+                    "error in CUF op conversion\n");
+    signalPassFailure();
+  }
+};
+} // namespace
diff --git a/flang/test/Fir/CUDA/cuda-allocate.fir b/flang/test/Fir/CUDA/cuda-allocate.fir
new file mode 100644
index 0000000000000..ab4a253f33dd8
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-allocate.fir
@@ -0,0 +1,21 @@
+// RUN: fir-opt --cuf-convert %s | FileCheck %s
+
+// Local device allocatable: the descriptor comes from cuf.alloc (not from a
+// global), so cuf.allocate and cuf.deallocate are both rewritten into calls
+// to the standard runtime entry points AllocatableAllocate and
+// AllocatableDeallocate. Neither operation has stat or errmsg operands here.
+func.func @_QPsub1() {
+ %0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Ea"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+ %4:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+ // The constants below are not used by the (de)allocation operations.
+ %c1 = arith.constant 1 : index
+ %c10_i32 = arith.constant 10 : i32
+ %c0_i32 = arith.constant 0 : i32
+ %9 = cuf.allocate %4#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>} -> i32
+ %10 = cuf.deallocate %4#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>} -> i32
+ return
+}
+
+// CHECK-LABEL: func.func @_QPsub1()
+// CHECK: %[[DESC:.*]] = cuf.alloc !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Ea"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+// CHECK: %[[DECL_DESC:.*]]:2 = hlfir.declare %[[DESC]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DECL_DESC]]#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: %{{.*}} = fir.call @_FortranAAllocatableAllocate(%[[BOX_NONE]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
+
+// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DECL_DESC]]#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: %{{.*}} = fir.call @_FortranAAllocatableDeallocate(%[[BOX_NONE]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
More information about the flang-commits
mailing list