[flang-commits] [flang] 21e64d1 - [flang][cuda][NFC] Split allocation related operation conversion from other cuf operations (#169740)
via flang-commits
flang-commits at lists.llvm.org
Mon Dec 1 10:19:56 PST 2025
Author: Valentin Clement (バレンタイン クレメン)
Date: 2025-12-01T10:19:52-08:00
New Revision: 21e64d1f5a3dbf539eaf9c7ac160469e60222ba2
URL: https://github.com/llvm/llvm-project/commit/21e64d1f5a3dbf539eaf9c7ac160469e60222ba2
DIFF: https://github.com/llvm/llvm-project/commit/21e64d1f5a3dbf539eaf9c7ac160469e60222ba2.diff
LOG: [flang][cuda][NFC] Split allocation related operation conversion from other cuf operations (#169740)
Split the conversion of AllocOp, FreeOp, AllocateOp, and DeallocateOp from the
other CUF operation conversions. The allocation patterns are still added to the
base CUFOpConversion pass when its new allocation-conversion option is enabled
(the default).
This split is a prerequisite for more flexibility in where the conversion of
allocation-related operations happens in the pipeline.
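A note on intended usage (not part of the commit): with this split, a pipeline
can run the allocation conversion as its own pass and disable the corresponding
patterns in the combined pass. A minimal sketch, assuming the pass factories
and options struct that TableGen generates from the Passes.td entries below
(fir::createCUFAllocationConversion, fir::CUFOpConversionOptions):

  #include "flang/Optimizer/Transforms/Passes.h"
  #include "mlir/Pass/PassManager.h"

  // Schedule the allocation-related CUF conversion on its own, then run the
  // remaining CUF conversion with its allocation patterns disabled.
  static void addCUFConversionPasses(mlir::PassManager &pm) {
    pm.addPass(fir::createCUFAllocationConversion());
    fir::CUFOpConversionOptions options;
    options.allocationConversion = false;
    pm.addPass(fir::createCUFOpConversion(options));
  }

The same scheduling can be expressed with the pass flags defined below, e.g.
fir-opt --cuf-allocation-convert --cuf-convert="allocation-conversion=false".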
Added:
flang/include/flang/Optimizer/Transforms/CUDA/CUFAllocationConversion.h
flang/lib/Optimizer/Transforms/CUDA/CUFAllocationConversion.cpp
Modified:
flang/include/flang/Optimizer/Transforms/Passes.td
flang/lib/Optimizer/Transforms/CMakeLists.txt
flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Transforms/CUDA/CUFAllocationConversion.h b/flang/include/flang/Optimizer/Transforms/CUDA/CUFAllocationConversion.h
new file mode 100644
index 0000000000000..2a4eb1cdb27f0
--- /dev/null
+++ b/flang/include/flang/Optimizer/Transforms/CUDA/CUFAllocationConversion.h
@@ -0,0 +1,33 @@
+//===------- CUFAllocationConversion.h --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_OPTIMIZER_TRANSFORMS_CUDA_CUFALLOCATIONCONVERSION_H_
+#define FORTRAN_OPTIMIZER_TRANSFORMS_CUDA_CUFALLOCATIONCONVERSION_H_
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+namespace fir {
+class LLVMTypeConverter;
+}
+
+namespace mlir {
+class DataLayout;
+class SymbolTable;
+} // namespace mlir
+
+namespace cuf {
+
+/// Patterns converting allocation-related CUF operations to runtime calls.
+void populateCUFAllocationConversionPatterns(
+ const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
+ const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns);
+
+} // namespace cuf
+
+#endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUDA_CUFALLOCATIONCONVERSION_H_
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index f5403ab6ff503..f50202784e2dc 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -470,11 +470,19 @@ def AssumedRankOpConversion : Pass<"fir-assumed-rank-op", "mlir::ModuleOp"> {
];
}
+def CUFAllocationConversion : Pass<"cuf-allocation-convert", "mlir::ModuleOp"> {
+ let summary = "Convert allocation related CUF operations to runtime calls";
+ let dependentDialects = ["fir::FIROpsDialect"];
+}
+
def CUFOpConversion : Pass<"cuf-convert", "mlir::ModuleOp"> {
let summary = "Convert some CUF operations to runtime calls";
let dependentDialects = [
"fir::FIROpsDialect", "mlir::gpu::GPUDialect", "mlir::DLTIDialect"
];
+ let options = [Option<"allocationConversion", "allocation-conversion", "bool",
+ /*default=*/"true",
+ "Convert allocation-related operations with this pass">];
}
def CUFDeviceGlobal :
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 0388439f89a54..619f3adc67c85 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -9,6 +9,7 @@ add_flang_library(FIRTransforms
CompilerGeneratedNames.cpp
ConstantArgumentGlobalisation.cpp
ControlFlowConverter.cpp
+ CUDA/CUFAllocationConversion.cpp
CUFAddConstructor.cpp
CUFDeviceGlobal.cpp
CUFOpConversion.cpp
diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFAllocationConversion.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFAllocationConversion.cpp
new file mode 100644
index 0000000000000..0acdb24bf62b1
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFAllocationConversion.cpp
@@ -0,0 +1,468 @@
+//===-- CUFAllocationConversion.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Transforms/CUDA/CUFAllocationConversion.h"
+#include "flang/Optimizer/Builder/CUFCommon.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h"
+#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
+#include "flang/Optimizer/CodeGen/TypeConverter.h"
+#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Optimizer/Support/DataLayout.h"
+#include "flang/Runtime/CUDA/allocatable.h"
+#include "flang/Runtime/CUDA/common.h"
+#include "flang/Runtime/CUDA/descriptor.h"
+#include "flang/Runtime/CUDA/memory.h"
+#include "flang/Runtime/CUDA/pointer.h"
+#include "flang/Runtime/allocatable.h"
+#include "flang/Runtime/allocator-registry-consts.h"
+#include "flang/Support/Fortran.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace fir {
+#define GEN_PASS_DEF_CUFALLOCATIONCONVERSION
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace fir;
+using namespace mlir;
+using namespace Fortran::runtime;
+using namespace Fortran::runtime::cuda;
+
+namespace {
+
+template <typename OpTy>
+static bool isPinned(OpTy op) {
+ if (op.getDataAttr() && *op.getDataAttr() == cuf::DataAttribute::Pinned)
+ return true;
+ return false;
+}
+
+static inline unsigned getMemType(cuf::DataAttribute attr) {
+ if (attr == cuf::DataAttribute::Device)
+ return kMemTypeDevice;
+ if (attr == cuf::DataAttribute::Managed)
+ return kMemTypeManaged;
+ if (attr == cuf::DataAttribute::Pinned)
+ return kMemTypePinned;
+ if (attr == cuf::DataAttribute::Unified)
+ return kMemTypeUnified;
+ llvm_unreachable("unsupported memory type");
+}
+
+template <typename OpTy>
+static bool hasDoubleDescriptors(OpTy op) {
+ if (auto declareOp =
+ mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp())) {
+ if (mlir::isa_and_nonnull<fir::AddrOfOp>(
+ declareOp.getMemref().getDefiningOp())) {
+ if (isPinned(declareOp))
+ return false;
+ return true;
+ }
+ } else if (auto declareOp = mlir::dyn_cast_or_null<hlfir::DeclareOp>(
+ op.getBox().getDefiningOp())) {
+ if (mlir::isa_and_nonnull<fir::AddrOfOp>(
+ declareOp.getMemref().getDefiningOp())) {
+ if (isPinned(declareOp))
+ return false;
+ return true;
+ }
+ }
+ return false;
+}
+
+static bool inDeviceContext(mlir::Operation *op) {
+ if (op->getParentOfType<cuf::KernelOp>())
+ return true;
+ if (auto funcOp = op->getParentOfType<mlir::gpu::GPUFuncOp>())
+ return true;
+ if (auto funcOp = op->getParentOfType<mlir::gpu::LaunchOp>())
+ return true;
+ if (auto funcOp = op->getParentOfType<mlir::func::FuncOp>()) {
+ if (auto cudaProcAttr =
+ funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
+ cuf::getProcAttrName())) {
+ return cudaProcAttr.getValue() != cuf::ProcAttribute::Host &&
+ cudaProcAttr.getValue() != cuf::ProcAttribute::HostDevice;
+ }
+ }
+ return false;
+}
+
+template <typename OpTy>
+static mlir::LogicalResult convertOpToCall(OpTy op,
+ mlir::PatternRewriter &rewriter,
+ mlir::func::FuncOp func) {
+ auto mod = op->template getParentOfType<mlir::ModuleOp>();
+ fir::FirOpBuilder builder(rewriter, mod);
+ mlir::Location loc = op.getLoc();
+ auto fTy = func.getFunctionType();
+
+ mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+ mlir::Value sourceLine;
+ if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>)
+ sourceLine = fir::factory::locationToLineNo(
+ builder, loc, op.getSource() ? fTy.getInput(7) : fTy.getInput(6));
+ else
+ sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
+
+ mlir::Value hasStat = op.getHasStat() ? builder.createBool(loc, true)
+ : builder.createBool(loc, false);
+
+ mlir::Value errmsg;
+ if (op.getErrmsg()) {
+ errmsg = op.getErrmsg();
+ } else {
+ mlir::Type boxNoneTy = fir::BoxType::get(builder.getNoneType());
+ errmsg = fir::AbsentOp::create(builder, loc, boxNoneTy).getResult();
+ }
+ llvm::SmallVector<mlir::Value> args;
+ if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>) {
+ mlir::Value pinned =
+ op.getPinned()
+ ? op.getPinned()
+ : builder.createNullConstant(
+ loc, fir::ReferenceType::get(
+ mlir::IntegerType::get(op.getContext(), 1)));
+ if (op.getSource()) {
+ mlir::Value stream =
+ op.getStream() ? op.getStream()
+ : builder.createNullConstant(loc, fTy.getInput(2));
+ args = fir::runtime::createArguments(
+ builder, loc, fTy, op.getBox(), op.getSource(), stream, pinned,
+ hasStat, errmsg, sourceFile, sourceLine);
+ } else {
+ mlir::Value stream =
+ op.getStream() ? op.getStream()
+ : builder.createNullConstant(loc, fTy.getInput(1));
+ args = fir::runtime::createArguments(builder, loc, fTy, op.getBox(),
+ stream, pinned, hasStat, errmsg,
+ sourceFile, sourceLine);
+ }
+ } else {
+ args =
+ fir::runtime::createArguments(builder, loc, fTy, op.getBox(), hasStat,
+ errmsg, sourceFile, sourceLine);
+ }
+ auto callOp = fir::CallOp::create(builder, loc, func, args);
+ rewriter.replaceOp(op, callOp);
+ return mlir::success();
+}
+
+struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ CUFAllocOpConversion(mlir::MLIRContext *context, mlir::DataLayout *dl,
+ const fir::LLVMTypeConverter *typeConverter)
+ : OpRewritePattern(context), dl{dl}, typeConverter{typeConverter} {}
+
+ mlir::LogicalResult
+ matchAndRewrite(cuf::AllocOp op,
+ mlir::PatternRewriter &rewriter) const override {
+
+ mlir::Location loc = op.getLoc();
+
+ if (inDeviceContext(op.getOperation())) {
+ // In device context, just replace the cuf.alloc operation with a fir.alloca;
+ // the corresponding cuf.free will be removed.
+ auto allocaOp =
+ fir::AllocaOp::create(rewriter, loc, op.getInType(),
+ op.getUniqName() ? *op.getUniqName() : "",
+ op.getBindcName() ? *op.getBindcName() : "",
+ op.getTypeparams(), op.getShape());
+ allocaOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
+ rewriter.replaceOp(op, allocaOp);
+ return mlir::success();
+ }
+
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ fir::FirOpBuilder builder(rewriter, mod);
+ mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+
+ if (!mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType())) {
+ // Convert scalar and known size array allocations.
+ mlir::Value bytes;
+ fir::KindMapping kindMap{fir::getKindMapping(mod)};
+ if (fir::isa_trivial(op.getInType())) {
+ int width = cuf::computeElementByteSize(loc, op.getInType(), kindMap);
+ bytes =
+ builder.createIntegerConstant(loc, builder.getIndexType(), width);
+ } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
+ op.getInType())) {
+ std::size_t size = 0;
+ if (fir::isa_derived(seqTy.getEleTy())) {
+ mlir::Type structTy = typeConverter->convertType(seqTy.getEleTy());
+ size = dl->getTypeSizeInBits(structTy) / 8;
+ } else {
+ size = cuf::computeElementByteSize(loc, seqTy.getEleTy(), kindMap);
+ }
+ mlir::Value width =
+ builder.createIntegerConstant(loc, builder.getIndexType(), size);
+ mlir::Value nbElem;
+ if (fir::sequenceWithNonConstantShape(seqTy)) {
+ assert(!op.getShape().empty() && "expect shape with dynamic arrays");
+ nbElem = builder.loadIfRef(loc, op.getShape()[0]);
+ for (unsigned i = 1; i < op.getShape().size(); ++i) {
+ nbElem = mlir::arith::MulIOp::create(
+ rewriter, loc, nbElem,
+ builder.loadIfRef(loc, op.getShape()[i]));
+ }
+ } else {
+ nbElem = builder.createIntegerConstant(loc, builder.getIndexType(),
+ seqTy.getConstantArraySize());
+ }
+ bytes = mlir::arith::MulIOp::create(rewriter, loc, nbElem, width);
+ } else if (fir::isa_derived(op.getInType())) {
+ mlir::Type structTy = typeConverter->convertType(op.getInType());
+ std::size_t structSize = dl->getTypeSizeInBits(structTy) / 8;
+ bytes = builder.createIntegerConstant(loc, builder.getIndexType(),
+ structSize);
+ } else if (fir::isa_char(op.getInType())) {
+ mlir::Type charTy = typeConverter->convertType(op.getInType());
+ std::size_t charSize = dl->getTypeSizeInBits(charTy) / 8;
+ bytes = builder.createIntegerConstant(loc, builder.getIndexType(),
+ charSize);
+ } else {
+ mlir::emitError(loc, "unsupported type in cuf.alloc\n");
+ }
+ mlir::func::FuncOp func =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFMemAlloc)>(loc, builder);
+ auto fTy = func.getFunctionType();
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
+ mlir::Value memTy = builder.createIntegerConstant(
+ loc, builder.getI32Type(), getMemType(op.getDataAttr()));
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, bytes, memTy, sourceFile, sourceLine)};
+ auto callOp = fir::CallOp::create(builder, loc, func, args);
+ callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
+ auto convOp = builder.createConvert(loc, op.getResult().getType(),
+ callOp.getResult(0));
+ rewriter.replaceOp(op, convOp);
+ return mlir::success();
+ }
+
+ // Convert descriptor allocations to function call.
+ auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
+ mlir::func::FuncOp func =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDescriptor)>(loc, builder);
+ auto fTy = func.getFunctionType();
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+
+ mlir::Type structTy = typeConverter->convertBoxTypeAsStruct(boxTy);
+ std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
+ mlir::Value sizeInBytes =
+ builder.createIntegerConstant(loc, builder.getIndexType(), boxSize);
+
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, sizeInBytes, sourceFile, sourceLine)};
+ auto callOp = fir::CallOp::create(builder, loc, func, args);
+ callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
+ auto convOp = builder.createConvert(loc, op.getResult().getType(),
+ callOp.getResult(0));
+ rewriter.replaceOp(op, convOp);
+ return mlir::success();
+ }
+
+private:
+ mlir::DataLayout *dl;
+ const fir::LLVMTypeConverter *typeConverter;
+};
+
+struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(cuf::FreeOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ if (inDeviceContext(op.getOperation())) {
+ rewriter.eraseOp(op);
+ return mlir::success();
+ }
+
+ if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
+ return failure();
+
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ fir::FirOpBuilder builder(rewriter, mod);
+ mlir::Location loc = op.getLoc();
+ mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+
+ auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
+ if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy())) {
+ mlir::func::FuncOp func =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFMemFree)>(loc, builder);
+ auto fTy = func.getFunctionType();
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
+ mlir::Value memTy = builder.createIntegerConstant(
+ loc, builder.getI32Type(), getMemType(op.getDataAttr()));
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, op.getDevptr(), memTy, sourceFile, sourceLine)};
+ fir::CallOp::create(builder, loc, func, args);
+ rewriter.eraseOp(op);
+ return mlir::success();
+ }
+
+ // Convert cuf.free on descriptors.
+ mlir::func::FuncOp func =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDescriptor)>(loc, builder);
+ auto fTy = func.getFunctionType();
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, op.getDevptr(), sourceFile, sourceLine)};
+ auto callOp = fir::CallOp::create(builder, loc, func, args);
+ callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
+ rewriter.eraseOp(op);
+ return mlir::success();
+ }
+};
+
+struct CUFAllocateOpConversion
+ : public mlir::OpRewritePattern<cuf::AllocateOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(cuf::AllocateOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ fir::FirOpBuilder builder(rewriter, mod);
+ mlir::Location loc = op.getLoc();
+
+ bool isPointer = false;
+
+ if (auto declareOp =
+ mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp()))
+ if (declareOp.getFortranAttrs() &&
+ bitEnumContainsAny(*declareOp.getFortranAttrs(),
+ fir::FortranVariableFlagsEnum::pointer))
+ isPointer = true;
+
+ if (hasDoubleDescriptors(op)) {
+ // Allocations for module variables are done with a custom runtime entry
+ // point so the descriptors can be synchronized.
+ mlir::func::FuncOp func;
+ if (op.getSource()) {
+ func = isPointer ? fir::runtime::getRuntimeFunc<mkRTKey(
+ CUFPointerAllocateSourceSync)>(loc, builder)
+ : fir::runtime::getRuntimeFunc<mkRTKey(
+ CUFAllocatableAllocateSourceSync)>(loc, builder);
+ } else {
+ func =
+ isPointer
+ ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSync)>(
+ loc, builder)
+ : fir::runtime::getRuntimeFunc<mkRTKey(
+ CUFAllocatableAllocateSync)>(loc, builder);
+ }
+ return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
+ }
+
+ mlir::func::FuncOp func;
+ if (op.getSource()) {
+ func =
+ isPointer
+ ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSource)>(
+ loc, builder)
+ : fir::runtime::getRuntimeFunc<mkRTKey(
+ CUFAllocatableAllocateSource)>(loc, builder);
+ } else {
+ func =
+ isPointer
+ ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocate)>(
+ loc, builder)
+ : fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocate)>(
+ loc, builder);
+ }
+
+ return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
+ }
+};
+
+struct CUFDeallocateOpConversion
+ : public mlir::OpRewritePattern<cuf::DeallocateOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(cuf::DeallocateOp op,
+ mlir::PatternRewriter &rewriter) const override {
+
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ fir::FirOpBuilder builder(rewriter, mod);
+ mlir::Location loc = op.getLoc();
+
+ if (hasDoubleDescriptors(op)) {
+ // Deallocations for module variables are done with a custom runtime entry
+ // point so the descriptors can be synchronized.
+ mlir::func::FuncOp func =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableDeallocate)>(
+ loc, builder);
+ return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
+ }
+
+ // Deallocation for local descriptors falls back on the standard runtime
+ // AllocatableDeallocate, as the dedicated deallocator is set in the
+ // descriptor before the call.
+ mlir::func::FuncOp func =
+ fir::runtime::getRuntimeFunc<mkRTKey(AllocatableDeallocate)>(loc,
+ builder);
+ return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
+ }
+};
+
+class CUFAllocationConversion
+ : public fir::impl::CUFAllocationConversionBase<CUFAllocationConversion> {
+public:
+ void runOnOperation() override {
+ auto *ctx = &getContext();
+ mlir::RewritePatternSet patterns(ctx);
+ mlir::ConversionTarget target(*ctx);
+
+ mlir::Operation *op = getOperation();
+ mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op);
+ if (!module)
+ return signalPassFailure();
+ mlir::SymbolTable symtab(module);
+
+ std::optional<mlir::DataLayout> dl = fir::support::getOrSetMLIRDataLayout(
+ module, /*allowDefaultLayout=*/false);
+ fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
+ /*forceUnifiedTBAATree=*/false, *dl);
+ target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
+ mlir::gpu::GPUDialect>();
+ target.addLegalOp<cuf::StreamCastOp>();
+ cuf::populateCUFAllocationConversionPatterns(typeConverter, *dl, symtab,
+ patterns);
+ if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
+ std::move(patterns)))) {
+ mlir::emitError(mlir::UnknownLoc::get(ctx),
+ "error in CUF allocation conversion\n");
+ signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+void cuf::populateCUFAllocationConversionPatterns(
+ const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
+ const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
+ patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter);
+ patterns.insert<CUFFreeOpConversion, CUFAllocateOpConversion,
+ CUFDeallocateOpConversion>(patterns.getContext());
+}
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 7ed34f865d0e9..f2ab99a8bc8ee 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -16,6 +16,8 @@
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/Support/DataLayout.h"
+#include "flang/Optimizer/Transforms/CUDA/CUFAllocationConversion.h"
+#include "flang/Optimizer/Transforms/Passes.h"
#include "flang/Runtime/CUDA/allocatable.h"
#include "flang/Runtime/CUDA/common.h"
#include "flang/Runtime/CUDA/descriptor.h"
@@ -44,207 +46,6 @@ using namespace Fortran::runtime::cuda;
namespace {
-static inline unsigned getMemType(cuf::DataAttribute attr) {
- if (attr == cuf::DataAttribute::Device)
- return kMemTypeDevice;
- if (attr == cuf::DataAttribute::Managed)
- return kMemTypeManaged;
- if (attr == cuf::DataAttribute::Unified)
- return kMemTypeUnified;
- if (attr == cuf::DataAttribute::Pinned)
- return kMemTypePinned;
- llvm::report_fatal_error("unsupported memory type");
-}
-
-template <typename OpTy>
-static bool isPinned(OpTy op) {
- if (op.getDataAttr() && *op.getDataAttr() == cuf::DataAttribute::Pinned)
- return true;
- return false;
-}
-
-template <typename OpTy>
-static bool hasDoubleDescriptors(OpTy op) {
- if (auto declareOp =
- mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp())) {
- if (mlir::isa_and_nonnull<fir::AddrOfOp>(
- declareOp.getMemref().getDefiningOp())) {
- if (isPinned(declareOp))
- return false;
- return true;
- }
- } else if (auto declareOp = mlir::dyn_cast_or_null<hlfir::DeclareOp>(
- op.getBox().getDefiningOp())) {
- if (mlir::isa_and_nonnull<fir::AddrOfOp>(
- declareOp.getMemref().getDefiningOp())) {
- if (isPinned(declareOp))
- return false;
- return true;
- }
- }
- return false;
-}
-
-static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
- mlir::Location loc, mlir::Type toTy,
- mlir::Value val) {
- if (val.getType() != toTy)
- return fir::ConvertOp::create(rewriter, loc, toTy, val);
- return val;
-}
-
-template <typename OpTy>
-static mlir::LogicalResult convertOpToCall(OpTy op,
- mlir::PatternRewriter &rewriter,
- mlir::func::FuncOp func) {
- auto mod = op->template getParentOfType<mlir::ModuleOp>();
- fir::FirOpBuilder builder(rewriter, mod);
- mlir::Location loc = op.getLoc();
- auto fTy = func.getFunctionType();
-
- mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
- mlir::Value sourceLine;
- if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>)
- sourceLine = fir::factory::locationToLineNo(
- builder, loc, op.getSource() ? fTy.getInput(7) : fTy.getInput(6));
- else
- sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
-
- mlir::Value hasStat = op.getHasStat() ? builder.createBool(loc, true)
- : builder.createBool(loc, false);
-
- mlir::Value errmsg;
- if (op.getErrmsg()) {
- errmsg = op.getErrmsg();
- } else {
- mlir::Type boxNoneTy = fir::BoxType::get(builder.getNoneType());
- errmsg = fir::AbsentOp::create(builder, loc, boxNoneTy).getResult();
- }
- llvm::SmallVector<mlir::Value> args;
- if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>) {
- mlir::Value pinned =
- op.getPinned()
- ? op.getPinned()
- : builder.createNullConstant(
- loc, fir::ReferenceType::get(
- mlir::IntegerType::get(op.getContext(), 1)));
- if (op.getSource()) {
- mlir::Value stream =
- op.getStream() ? op.getStream()
- : builder.createNullConstant(loc, fTy.getInput(2));
- args = fir::runtime::createArguments(
- builder, loc, fTy, op.getBox(), op.getSource(), stream, pinned,
- hasStat, errmsg, sourceFile, sourceLine);
- } else {
- mlir::Value stream =
- op.getStream() ? op.getStream()
- : builder.createNullConstant(loc, fTy.getInput(1));
- args = fir::runtime::createArguments(builder, loc, fTy, op.getBox(),
- stream, pinned, hasStat, errmsg,
- sourceFile, sourceLine);
- }
- } else {
- args =
- fir::runtime::createArguments(builder, loc, fTy, op.getBox(), hasStat,
- errmsg, sourceFile, sourceLine);
- }
- auto callOp = fir::CallOp::create(builder, loc, func, args);
- rewriter.replaceOp(op, callOp);
- return mlir::success();
-}
-
-struct CUFAllocateOpConversion
- : public mlir::OpRewritePattern<cuf::AllocateOp> {
- using OpRewritePattern::OpRewritePattern;
-
- mlir::LogicalResult
- matchAndRewrite(cuf::AllocateOp op,
- mlir::PatternRewriter &rewriter) const override {
- auto mod = op->getParentOfType<mlir::ModuleOp>();
- fir::FirOpBuilder builder(rewriter, mod);
- mlir::Location loc = op.getLoc();
-
- bool isPointer = false;
-
- if (auto declareOp =
- mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp()))
- if (declareOp.getFortranAttrs() &&
- bitEnumContainsAny(*declareOp.getFortranAttrs(),
- fir::FortranVariableFlagsEnum::pointer))
- isPointer = true;
-
- if (hasDoubleDescriptors(op)) {
- // Allocation for module variable are done with custom runtime entry point
- // so the descriptors can be synchronized.
- mlir::func::FuncOp func;
- if (op.getSource()) {
- func = isPointer ? fir::runtime::getRuntimeFunc<mkRTKey(
- CUFPointerAllocateSourceSync)>(loc, builder)
- : fir::runtime::getRuntimeFunc<mkRTKey(
- CUFAllocatableAllocateSourceSync)>(loc, builder);
- } else {
- func =
- isPointer
- ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSync)>(
- loc, builder)
- : fir::runtime::getRuntimeFunc<mkRTKey(
- CUFAllocatableAllocateSync)>(loc, builder);
- }
- return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
- }
-
- mlir::func::FuncOp func;
- if (op.getSource()) {
- func =
- isPointer
- ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSource)>(
- loc, builder)
- : fir::runtime::getRuntimeFunc<mkRTKey(
- CUFAllocatableAllocateSource)>(loc, builder);
- } else {
- func =
- isPointer
- ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocate)>(
- loc, builder)
- : fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocate)>(
- loc, builder);
- }
-
- return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
- }
-};
-
-struct CUFDeallocateOpConversion
- : public mlir::OpRewritePattern<cuf::DeallocateOp> {
- using OpRewritePattern::OpRewritePattern;
-
- mlir::LogicalResult
- matchAndRewrite(cuf::DeallocateOp op,
- mlir::PatternRewriter &rewriter) const override {
-
- auto mod = op->getParentOfType<mlir::ModuleOp>();
- fir::FirOpBuilder builder(rewriter, mod);
- mlir::Location loc = op.getLoc();
-
- if (hasDoubleDescriptors(op)) {
- // Deallocation for module variable are done with custom runtime entry
- // point so the descriptors can be synchronized.
- mlir::func::FuncOp func =
- fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableDeallocate)>(
- loc, builder);
- return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
- }
-
- // Deallocation for local descriptor falls back on the standard runtime
- // AllocatableDeallocate as the dedicated deallocator is set in the
- // descriptor before the call.
- mlir::func::FuncOp func =
- fir::runtime::getRuntimeFunc<mkRTKey(AllocatableDeallocate)>(loc,
- builder);
- return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
- }
-};
-
static bool inDeviceContext(mlir::Operation *op) {
if (op->getParentOfType<cuf::KernelOp>())
return true;
@@ -263,126 +64,13 @@ static bool inDeviceContext(mlir::Operation *op) {
return false;
}
-struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
- using OpRewritePattern::OpRewritePattern;
-
- CUFAllocOpConversion(mlir::MLIRContext *context, mlir::DataLayout *dl,
- const fir::LLVMTypeConverter *typeConverter)
- : OpRewritePattern(context), dl{dl}, typeConverter{typeConverter} {}
-
- mlir::LogicalResult
- matchAndRewrite(cuf::AllocOp op,
- mlir::PatternRewriter &rewriter) const override {
-
- mlir::Location loc = op.getLoc();
-
- if (inDeviceContext(op.getOperation())) {
- // In device context just replace the cuf.alloc operation with a fir.alloc
- // the cuf.free will be removed.
- auto allocaOp =
- fir::AllocaOp::create(rewriter, loc, op.getInType(),
- op.getUniqName() ? *op.getUniqName() : "",
- op.getBindcName() ? *op.getBindcName() : "",
- op.getTypeparams(), op.getShape());
- allocaOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
- rewriter.replaceOp(op, allocaOp);
- return mlir::success();
- }
-
- auto mod = op->getParentOfType<mlir::ModuleOp>();
- fir::FirOpBuilder builder(rewriter, mod);
- mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
-
- if (!mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType())) {
- // Convert scalar and known size array allocations.
- mlir::Value bytes;
- fir::KindMapping kindMap{fir::getKindMapping(mod)};
- if (fir::isa_trivial(op.getInType())) {
- int width = cuf::computeElementByteSize(loc, op.getInType(), kindMap);
- bytes =
- builder.createIntegerConstant(loc, builder.getIndexType(), width);
- } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
- op.getInType())) {
- std::size_t size = 0;
- if (fir::isa_derived(seqTy.getEleTy())) {
- mlir::Type structTy = typeConverter->convertType(seqTy.getEleTy());
- size = dl->getTypeSizeInBits(structTy) / 8;
- } else {
- size = cuf::computeElementByteSize(loc, seqTy.getEleTy(), kindMap);
- }
- mlir::Value width =
- builder.createIntegerConstant(loc, builder.getIndexType(), size);
- mlir::Value nbElem;
- if (fir::sequenceWithNonConstantShape(seqTy)) {
- assert(!op.getShape().empty() && "expect shape with dynamic arrays");
- nbElem = builder.loadIfRef(loc, op.getShape()[0]);
- for (unsigned i = 1; i < op.getShape().size(); ++i) {
- nbElem = mlir::arith::MulIOp::create(
- rewriter, loc, nbElem,
- builder.loadIfRef(loc, op.getShape()[i]));
- }
- } else {
- nbElem = builder.createIntegerConstant(loc, builder.getIndexType(),
- seqTy.getConstantArraySize());
- }
- bytes = mlir::arith::MulIOp::create(rewriter, loc, nbElem, width);
- } else if (fir::isa_derived(op.getInType())) {
- mlir::Type structTy = typeConverter->convertType(op.getInType());
- std::size_t structSize = dl->getTypeSizeInBits(structTy) / 8;
- bytes = builder.createIntegerConstant(loc, builder.getIndexType(),
- structSize);
- } else if (fir::isa_char(op.getInType())) {
- mlir::Type charTy = typeConverter->convertType(op.getInType());
- std::size_t charSize = dl->getTypeSizeInBits(charTy) / 8;
- bytes = builder.createIntegerConstant(loc, builder.getIndexType(),
- charSize);
- } else {
- mlir::emitError(loc, "unsupported type in cuf.alloc\n");
- }
- mlir::func::FuncOp func =
- fir::runtime::getRuntimeFunc<mkRTKey(CUFMemAlloc)>(loc, builder);
- auto fTy = func.getFunctionType();
- mlir::Value sourceLine =
- fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
- mlir::Value memTy = builder.createIntegerConstant(
- loc, builder.getI32Type(), getMemType(op.getDataAttr()));
- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
- builder, loc, fTy, bytes, memTy, sourceFile, sourceLine)};
- auto callOp = fir::CallOp::create(builder, loc, func, args);
- callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
- auto convOp = builder.createConvert(loc, op.getResult().getType(),
- callOp.getResult(0));
- rewriter.replaceOp(op, convOp);
- return mlir::success();
- }
-
- // Convert descriptor allocations to function call.
- auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
- mlir::func::FuncOp func =
- fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDescriptor)>(loc, builder);
- auto fTy = func.getFunctionType();
- mlir::Value sourceLine =
- fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
-
- mlir::Type structTy = typeConverter->convertBoxTypeAsStruct(boxTy);
- std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
- mlir::Value sizeInBytes =
- builder.createIntegerConstant(loc, builder.getIndexType(), boxSize);
-
- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
- builder, loc, fTy, sizeInBytes, sourceFile, sourceLine)};
- auto callOp = fir::CallOp::create(builder, loc, func, args);
- callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
- auto convOp = builder.createConvert(loc, op.getResult().getType(),
- callOp.getResult(0));
- rewriter.replaceOp(op, convOp);
- return mlir::success();
- }
-
-private:
- mlir::DataLayout *dl;
- const fir::LLVMTypeConverter *typeConverter;
-};
+static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
+ mlir::Location loc, mlir::Type toTy,
+ mlir::Value val) {
+ if (val.getType() != toTy)
+ return fir::ConvertOp::create(rewriter, loc, toTy, val);
+ return val;
+}
struct CUFDeviceAddressOpConversion
: public mlir::OpRewritePattern<cuf::DeviceAddressOp> {
@@ -460,56 +148,6 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
const mlir::SymbolTable &symTab;
};
-struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
- using OpRewritePattern::OpRewritePattern;
-
- mlir::LogicalResult
- matchAndRewrite(cuf::FreeOp op,
- mlir::PatternRewriter &rewriter) const override {
- if (inDeviceContext(op.getOperation())) {
- rewriter.eraseOp(op);
- return mlir::success();
- }
-
- if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
- return failure();
-
- auto mod = op->getParentOfType<mlir::ModuleOp>();
- fir::FirOpBuilder builder(rewriter, mod);
- mlir::Location loc = op.getLoc();
- mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
-
- auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
- if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy())) {
- mlir::func::FuncOp func =
- fir::runtime::getRuntimeFunc<mkRTKey(CUFMemFree)>(loc, builder);
- auto fTy = func.getFunctionType();
- mlir::Value sourceLine =
- fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
- mlir::Value memTy = builder.createIntegerConstant(
- loc, builder.getI32Type(), getMemType(op.getDataAttr()));
- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
- builder, loc, fTy, op.getDevptr(), memTy, sourceFile, sourceLine)};
- fir::CallOp::create(builder, loc, func, args);
- rewriter.eraseOp(op);
- return mlir::success();
- }
-
- // Convert cuf.free on descriptors.
- mlir::func::FuncOp func =
- fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDescriptor)>(loc, builder);
- auto fTy = func.getFunctionType();
- mlir::Value sourceLine =
- fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
- builder, loc, fTy, op.getDevptr(), sourceFile, sourceLine)};
- auto callOp = fir::CallOp::create(builder, loc, func, args);
- callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
- rewriter.eraseOp(op);
- return mlir::success();
- }
-};
-
static bool isDstGlobal(cuf::DataTransferOp op) {
if (auto declareOp = op.getDst().getDefiningOp<fir::DeclareOp>())
if (declareOp.getMemref().getDefiningOp<fir::AddrOfOp>())
@@ -896,6 +534,8 @@ struct CUFSyncDescriptorOpConversion
};
class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
+ using CUFOpConversionBase::CUFOpConversionBase;
+
public:
void runOnOperation() override {
auto *ctx = &getContext();
@@ -917,6 +557,8 @@ class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
target.addLegalOp<cuf::StreamCastOp>();
cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, symtab,
patterns);
+ cuf::populateCUFAllocationConversionPatterns(typeConverter, *dl, symtab,
+ patterns);
if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
mlir::emitError(mlir::UnknownLoc::get(ctx),
@@ -956,10 +598,7 @@ class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
void cuf::populateCUFToFIRConversionPatterns(
const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
- patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter);
- patterns.insert<CUFAllocateOpConversion, CUFDeallocateOpConversion,
- CUFFreeOpConversion, CUFSyncDescriptorOpConversion>(
- patterns.getContext());
+ patterns.insert<CUFSyncDescriptorOpConversion>(patterns.getContext());
patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
&dl, &converter);
patterns.insert<CUFLaunchOpConversion, CUFDeviceAddressOpConversion>(
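For downstream users of the new header, a minimal sketch (not from the commit)
of driving only the allocation-related conversion through the new
cuf::populateCUFAllocationConversionPatterns entry point; the setup mirrors the
CUFAllocationConversion pass body above:

  #include "flang/Optimizer/CodeGen/TypeConverter.h"
  #include "flang/Optimizer/Dialect/CUF/CUFOps.h"
  #include "flang/Optimizer/Dialect/FIRDialect.h"
  #include "flang/Optimizer/Support/DataLayout.h"
  #include "flang/Optimizer/Transforms/CUDA/CUFAllocationConversion.h"
  #include "mlir/Dialect/Arith/IR/Arith.h"
  #include "mlir/Dialect/GPU/IR/GPUDialect.h"
  #include "mlir/Transforms/DialectConversion.h"

  static mlir::LogicalResult convertCUFAllocations(mlir::ModuleOp module) {
    mlir::MLIRContext *ctx = module.getContext();
    mlir::SymbolTable symtab(module);
    std::optional<mlir::DataLayout> dl = fir::support::getOrSetMLIRDataLayout(
        module, /*allowDefaultLayout=*/false);
    fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
                                         /*forceUnifiedTBAATree=*/false, *dl);
    mlir::RewritePatternSet patterns(ctx);
    mlir::ConversionTarget target(*ctx);
    // Ops produced by the patterns stay legal; other CUF ops are left alone
    // by the partial conversion.
    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
                           mlir::gpu::GPUDialect>();
    target.addLegalOp<cuf::StreamCastOp>();
    cuf::populateCUFAllocationConversionPatterns(typeConverter, *dl, symtab,
                                                 patterns);
    return mlir::applyPartialConversion(module, target, std::move(patterns));
  }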