[flang-commits] [flang] [flang][cuda] Convert cuf.alloc and cuf.free for scalar and arrays (PR #110055)
via flang-commits
flang-commits at lists.llvm.org
Wed Sep 25 15:39:07 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
This patch extends the conversion of cuf.alloc and cuf.free to cover scalars, constant-size arrays, and dynamic-size arrays.
---
Full diff: https://github.com/llvm/llvm-project/pull/110055.diff
3 Files Affected:
- (modified) flang/lib/Optimizer/Transforms/CufOpConversion.cpp (+92-49)
- (added) flang/test/Fir/CUDA/cuda-alloc-free.fir (+64)
- (modified) flang/test/Fir/CUDA/cuda-allocate.fir (-11)
``````````diff
diff --git a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
index f8ace2dd96a0d8..fff026bbda2d40 100644
--- a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
@@ -183,6 +183,29 @@ static bool inDeviceContext(mlir::Operation *op) {
return false;
}
+static int computeWidth(mlir::Location loc, mlir::Type type,
+ fir::KindMapping &kindMap) {
+ auto eleTy = fir::unwrapSequenceType(type);
+ int width = 0;
+ if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)}) {
+ width = t.getWidth() / 8;
+ } else if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)}) {
+ width = t.getWidth() / 8;
+ } else if (eleTy.isInteger(1)) {
+ width = 1;
+ } else if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)}) {
+ int kind = t.getFKind();
+ width = kindMap.getLogicalBitsize(kind) / 8;
+ } else if (auto t{mlir::dyn_cast<fir::ComplexType>(eleTy)}) {
+ int kind = t.getFKind();
+ int elemSize = kindMap.getRealBitsize(kind) / 8;
+ width = 2 * elemSize;
+ } else {
+ llvm::report_fatal_error("unsupported type");
+ }
+ return width;
+}
+
struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
using OpRewritePattern::OpRewritePattern;
@@ -193,11 +216,6 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
mlir::LogicalResult
matchAndRewrite(cuf::AllocOp op,
mlir::PatternRewriter &rewriter) const override {
- auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
-
- // Only convert cuf.alloc that allocates a descriptor.
- if (!boxTy)
- return failure();
if (inDeviceContext(op.getOperation())) {
// In device context just replace the cuf.alloc operation with a fir.alloc
@@ -212,11 +230,56 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
auto mod = op->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, mod);
mlir::Location loc = op.getLoc();
+ mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+
+ if (!mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType())) {
+ // Convert scalar and known size array allocations.
+ mlir::Value bytes;
+ fir::KindMapping kindMap{fir::getKindMapping(mod)};
+ if (fir::isa_trivial(op.getInType())) {
+ int width = computeWidth(loc, op.getInType(), kindMap);
+ bytes =
+ builder.createIntegerConstant(loc, builder.getIndexType(), width);
+ } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
+ op.getInType())) {
+ mlir::Value width = builder.createIntegerConstant(
+ loc, builder.getIndexType(),
+ computeWidth(loc, seqTy.getEleTy(), kindMap));
+ mlir::Value nbElem;
+ if (fir::sequenceWithNonConstantShape(seqTy)) {
+ assert(!op.getShape().empty() && "expect shape with dynamic arrays");
+ nbElem = builder.loadIfRef(loc, op.getShape()[0]);
+ for (unsigned i = 1; i < op.getShape().size(); ++i) {
+ nbElem = rewriter.create<mlir::arith::MulIOp>(
+ loc, nbElem, builder.loadIfRef(loc, op.getShape()[i]));
+ }
+ } else {
+ nbElem = builder.createIntegerConstant(loc, builder.getIndexType(),
+ seqTy.getConstantArraySize());
+ }
+ bytes = rewriter.create<mlir::arith::MulIOp>(loc, nbElem, width);
+ }
+ mlir::func::FuncOp func =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFMemAlloc)>(loc, builder);
+ auto fTy = func.getFunctionType();
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
+ mlir::Value memTy = builder.createIntegerConstant(
+ loc, builder.getI32Type(), kMemTypeDevice);
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, bytes, memTy, sourceFile, sourceLine)};
+ auto callOp = builder.create<fir::CallOp>(loc, func, args);
+ auto convOp = builder.createConvert(loc, op.getResult().getType(),
+ callOp.getResult(0));
+ rewriter.replaceOp(op, convOp);
+ return mlir::success();
+ }
+
+ // Convert descriptor allocations to function call.
+ auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
mlir::func::FuncOp func =
fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDesciptor)>(loc, builder);
-
auto fTy = func.getFunctionType();
- mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
@@ -245,26 +308,39 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
mlir::LogicalResult
matchAndRewrite(cuf::FreeOp op,
mlir::PatternRewriter &rewriter) const override {
- // Only convert cuf.free on descriptor.
- if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
- return failure();
- auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
- if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy()))
- return failure();
-
if (inDeviceContext(op.getOperation())) {
rewriter.eraseOp(op);
return mlir::success();
}
+ if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
+ return failure();
+
auto mod = op->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, mod);
mlir::Location loc = op.getLoc();
+ mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+
+ auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
+ if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy())) {
+ mlir::func::FuncOp func =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFMemFree)>(loc, builder);
+ auto fTy = func.getFunctionType();
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
+ mlir::Value memTy = builder.createIntegerConstant(
+ loc, builder.getI32Type(), kMemTypeDevice);
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, op.getDevptr(), memTy, sourceFile, sourceLine)};
+ builder.create<fir::CallOp>(loc, func, args);
+ rewriter.eraseOp(op);
+ return mlir::success();
+ }
+
+ // Convert cuf.free on descriptors.
mlir::func::FuncOp func =
fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDesciptor)>(loc, builder);
-
auto fTy = func.getFunctionType();
- mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
@@ -275,29 +351,6 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
}
};
-static int computeWidth(mlir::Location loc, mlir::Type type,
- fir::KindMapping &kindMap) {
- auto eleTy = fir::unwrapSequenceType(type);
- int width = 0;
- if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)}) {
- width = t.getWidth() / 8;
- } else if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)}) {
- width = t.getWidth() / 8;
- } else if (eleTy.isInteger(1)) {
- width = 1;
- } else if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)}) {
- int kind = t.getFKind();
- width = kindMap.getLogicalBitsize(kind) / 8;
- } else if (auto t{mlir::dyn_cast<fir::ComplexType>(eleTy)}) {
- int kind = t.getFKind();
- int elemSize = kindMap.getRealBitsize(kind) / 8;
- width = 2 * elemSize;
- } else {
- llvm::report_fatal_error("unsupported type");
- }
- return width;
-}
-
static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
mlir::Location loc, mlir::Type toTy,
mlir::Value val) {
@@ -456,16 +509,6 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false);
fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
/*forceUnifiedTBAATree=*/false, *dl);
- target.addDynamicallyLegalOp<cuf::AllocOp>([](::cuf::AllocOp op) {
- return !mlir::isa<fir::BaseBoxType>(op.getInType());
- });
- target.addDynamicallyLegalOp<cuf::FreeOp>([](::cuf::FreeOp op) {
- if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(
- op.getDevptr().getType())) {
- return !mlir::isa<fir::BaseBoxType>(refTy.getEleTy());
- }
- return true;
- });
target.addDynamicallyLegalOp<cuf::DataTransferOp>(
[](::cuf::DataTransferOp op) {
mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
diff --git a/flang/test/Fir/CUDA/cuda-alloc-free.fir b/flang/test/Fir/CUDA/cuda-alloc-free.fir
new file mode 100644
index 00000000000000..25821418a40f11
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-alloc-free.fir
@@ -0,0 +1,64 @@
+// RUN: fir-opt --cuf-convert %s | FileCheck %s
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
+
+func.func @_QPsub1() {
+ %0 = cuf.alloc i32 {bindc_name = "idev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eidev"} -> !fir.ref<i32>
+ %1:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eidev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ cuf.free %1#1 : !fir.ref<i32> {data_attr = #cuf.cuda<device>}
+ return
+}
+
+// CHECK-LABEL: func.func @_QPsub1()
+// CHECK: %[[BYTES:.*]] = fir.convert %c4{{.*}} : (index) -> i64
+// CHECK: %[[ALLOC:.*]] = fir.call @_FortranACUFMemAlloc(%[[BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[CONV:.*]] = fir.convert %3 : (!fir.llvm_ptr<i8>) -> !fir.ref<i32>
+// CHECK: %[[DECL:.*]]:2 = hlfir.declare %[[CONV]] {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eidev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK: %[[DEVPTR:.*]] = fir.convert %[[DECL]]#1 : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFMemFree(%[[DEVPTR]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, i32, !fir.ref<i8>, i32) -> none
+
+func.func @_QPsub2() {
+ %0 = cuf.alloc !fir.array<10xf32> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QMcuda_varFcuda_alloc_freeEa"} -> !fir.ref<!fir.array<10xf32>>
+ cuf.free %0 : !fir.ref<!fir.array<10xf32>> {data_attr = #cuf.cuda<device>}
+ return
+}
+
+// CHECK-LABEL: func.func @_QPsub2()
+// CHECK: %[[BYTES:.*]] = arith.muli %c10{{.*}}, %c4{{.*}} : index
+// CHECK: %[[CONV_BYTES:.*]] = fir.convert %[[BYTES]] : (index) -> i64
+// CHECK: %{{.*}} = fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFMemFree
+
+func.func @_QPsub3(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref<i32> {fir.bindc_name = "m"}) {
+ %0 = fir.dummy_scope : !fir.dscope
+ %1:2 = hlfir.declare %arg0 dummy_scope %0 {uniq_name = "_QFsub3En"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %2:2 = hlfir.declare %arg1 dummy_scope %0 {uniq_name = "_QFsub3Em"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %3 = fir.load %1#0 : !fir.ref<i32>
+ %4 = fir.convert %3 : (i32) -> i64
+ %5 = fir.convert %4 : (i64) -> index
+ %c0 = arith.constant 0 : index
+ %6 = arith.cmpi sgt, %5, %c0 : index
+ %7 = arith.select %6, %5, %c0 : index
+ %8 = fir.load %2#0 : !fir.ref<i32>
+ %9 = fir.convert %8 : (i32) -> i64
+ %10 = fir.convert %9 : (i64) -> index
+ %c0_0 = arith.constant 0 : index
+ %11 = arith.cmpi sgt, %10, %c0_0 : index
+ %12 = arith.select %11, %10, %c0_0 : index
+ %13 = cuf.alloc !fir.array<?x?xi32>, %7, %12 : index, index {bindc_name = "idev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eidev"} -> !fir.ref<!fir.array<?x?xi32>>
+ %14 = fir.shape %7, %12 : (index, index) -> !fir.shape<2>
+ %15:2 = hlfir.declare %13(%14) {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eidev"} : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> (!fir.box<!fir.array<?x?xi32>>, !fir.ref<!fir.array<?x?xi32>>)
+ cuf.free %15#1 : !fir.ref<!fir.array<?x?xi32>> {data_attr = #cuf.cuda<device>}
+ return
+}
+
+// CHECK-LABEL: func.func @_QPsub3
+// CHECK: %[[N:.*]] = arith.select
+// CHECK: %[[M:.*]] = arith.select
+// CHECK: %[[NBELEM:.*]] = arith.muli %[[N]], %[[M]] : index
+// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %c4{{.*}} : index
+// CHECK: %[[CONV_BYTES:.*]] = fir.convert %[[BYTES]] : (index) -> i64
+// CHECK: %{{.*}} = fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFMemFree
+
+} // end module
diff --git a/flang/test/Fir/CUDA/cuda-allocate.fir b/flang/test/Fir/CUDA/cuda-allocate.fir
index 65c68bb69301af..d68ff894d5af5a 100644
--- a/flang/test/Fir/CUDA/cuda-allocate.fir
+++ b/flang/test/Fir/CUDA/cuda-allocate.fir
@@ -26,17 +26,6 @@ func.func @_QPsub1() {
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DECL_DESC]]#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: fir.call @_FortranACUFFreeDesciptor(%[[BOX_NONE]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<i8>, i32) -> none
-// Check operations that should not be transformed yet.
-func.func @_QPsub2() {
- %0 = cuf.alloc !fir.array<10xf32> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QMcuda_varFcuda_alloc_freeEa"} -> !fir.ref<!fir.array<10xf32>>
- cuf.free %0 : !fir.ref<!fir.array<10xf32>> {data_attr = #cuf.cuda<device>}
- return
-}
-
-// CHECK-LABEL: func.func @_QPsub2()
-// CHECK: cuf.alloc !fir.array<10xf32>
-// CHECK: cuf.free %{{.*}} : !fir.ref<!fir.array<10xf32>>
-
fir.global @_QMmod1Ea {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?xf32>>> {
%0 = fir.zero_bits !fir.heap<!fir.array<?xf32>>
%c0 = arith.constant 0 : index
``````````
</details>
https://github.com/llvm/llvm-project/pull/110055
More information about the flang-commits mailing list