[flang-commits] [flang] 21e64d1 - [flang][cuda][NFC] Split allocation related operation conversion from other cuf operations (#169740)

via flang-commits flang-commits at lists.llvm.org
Mon Dec 1 10:19:56 PST 2025


Author: Valentin Clement (バレンタイン クレメン)
Date: 2025-12-01T10:19:52-08:00
New Revision: 21e64d1f5a3dbf539eaf9c7ac160469e60222ba2

URL: https://github.com/llvm/llvm-project/commit/21e64d1f5a3dbf539eaf9c7ac160469e60222ba2
DIFF: https://github.com/llvm/llvm-project/commit/21e64d1f5a3dbf539eaf9c7ac160469e60222ba2.diff

LOG: [flang][cuda][NFC] Split allocation related operation conversion from other cuf operations (#169740)

Split the conversion of AllocOp, FreeOp, AllocateOp, and DeallocateOp from
the rest of the cuf operation conversions. The allocation patterns are still
added to the base CUFOpConversion pass when its allocation-conversion option
is enabled (the default). This split is a prerequisite for more flexibility
in where the allocation-related operation conversions run in the pipeline.
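
A hypothetical usage sketch (not part of this commit), assuming fir-opt
registers these passes and accepts the standard MLIR pass-option syntax
(in.mlir is a placeholder input file):

    # Run only the allocation-related conversions via the new pass.
    fir-opt --cuf-allocation-convert in.mlir

    # Run the existing conversion pass but leave allocation-related
    # operations for a later pass.
    fir-opt --cuf-convert="allocation-conversion=false" in.mlir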

Added: 
    flang/include/flang/Optimizer/Transforms/CUDA/CUFAllocationConversion.h
    flang/lib/Optimizer/Transforms/CUDA/CUFAllocationConversion.cpp

Modified: 
    flang/include/flang/Optimizer/Transforms/Passes.td
    flang/lib/Optimizer/Transforms/CMakeLists.txt
    flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Removed: 
    


################################################################################
diff --git a/flang/include/flang/Optimizer/Transforms/CUDA/CUFAllocationConversion.h b/flang/include/flang/Optimizer/Transforms/CUDA/CUFAllocationConversion.h
new file mode 100644
index 0000000000000..2a4eb1cdb27f0
--- /dev/null
+++ b/flang/include/flang/Optimizer/Transforms/CUDA/CUFAllocationConversion.h
@@ -0,0 +1,33 @@
+//===------- CUFAllocationConversion.h --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_OPTIMIZER_TRANSFORMS_CUDA_CUFALLOCATIONCONVERSION_H_
+#define FORTRAN_OPTIMIZER_TRANSFORMS_CUDA_CUFALLOCATIONCONVERSION_H_
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+namespace fir {
+class LLVMTypeConverter;
+}
+
+namespace mlir {
+class DataLayout;
+class SymbolTable;
+} // namespace mlir
+
+namespace cuf {
+
+/// Populate patterns that convert allocation-related CUF operations to
+/// runtime calls.
+void populateCUFAllocationConversionPatterns(
+    const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
+    const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns);
+
+} // namespace cuf
+
+#endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUDA_CUFALLOCATIONCONVERSION_H_

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index f5403ab6ff503..f50202784e2dc 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -470,11 +470,19 @@ def AssumedRankOpConversion : Pass<"fir-assumed-rank-op", "mlir::ModuleOp"> {
   ];
 }
 
+def CUFAllocationConversion : Pass<"cuf-allocation-convert", "mlir::ModuleOp"> {
+  let summary = "Convert allocation related CUF operations to runtime calls";
+  let dependentDialects = ["fir::FIROpsDialect"];
+}
+
 def CUFOpConversion : Pass<"cuf-convert", "mlir::ModuleOp"> {
   let summary = "Convert some CUF operations to runtime calls";
   let dependentDialects = [
     "fir::FIROpsDialect", "mlir::gpu::GPUDialect", "mlir::DLTIDialect"
   ];
+  let options = [Option<"allocationConversion", "allocation-conversion", "bool",
+                        /*default=*/"true",
+                        "Convert allocation related operation with this pass">];
 }
 
 def CUFDeviceGlobal :

diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 0388439f89a54..619f3adc67c85 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -9,6 +9,7 @@ add_flang_library(FIRTransforms
   CompilerGeneratedNames.cpp
   ConstantArgumentGlobalisation.cpp
   ControlFlowConverter.cpp
+  CUDA/CUFAllocationConversion.cpp
   CUFAddConstructor.cpp
   CUFDeviceGlobal.cpp
   CUFOpConversion.cpp

diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFAllocationConversion.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFAllocationConversion.cpp
new file mode 100644
index 0000000000000..0acdb24bf62b1
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFAllocationConversion.cpp
@@ -0,0 +1,468 @@
+//===-- CUFAllocationConversion.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Transforms/CUDA/CUFAllocationConversion.h"
+#include "flang/Optimizer/Builder/CUFCommon.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h"
+#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
+#include "flang/Optimizer/CodeGen/TypeConverter.h"
+#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Optimizer/Support/DataLayout.h"
+#include "flang/Runtime/CUDA/allocatable.h"
+#include "flang/Runtime/CUDA/common.h"
+#include "flang/Runtime/CUDA/descriptor.h"
+#include "flang/Runtime/CUDA/memory.h"
+#include "flang/Runtime/CUDA/pointer.h"
+#include "flang/Runtime/allocatable.h"
+#include "flang/Runtime/allocator-registry-consts.h"
+#include "flang/Support/Fortran.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace fir {
+#define GEN_PASS_DEF_CUFALLOCATIONCONVERSION
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace fir;
+using namespace mlir;
+using namespace Fortran::runtime;
+using namespace Fortran::runtime::cuda;
+
+namespace {
+
+template <typename OpTy>
+static bool isPinned(OpTy op) {
+  if (op.getDataAttr() && *op.getDataAttr() == cuf::DataAttribute::Pinned)
+    return true;
+  return false;
+}
+
+static inline unsigned getMemType(cuf::DataAttribute attr) {
+  if (attr == cuf::DataAttribute::Device)
+    return kMemTypeDevice;
+  if (attr == cuf::DataAttribute::Managed)
+    return kMemTypeManaged;
+  if (attr == cuf::DataAttribute::Pinned)
+    return kMemTypePinned;
+  if (attr == cuf::DataAttribute::Unified)
+    return kMemTypeUnified;
+  llvm_unreachable("unsupported memory type");
+}
+
+template <typename OpTy>
+static bool hasDoubleDescriptors(OpTy op) {
+  if (auto declareOp =
+          mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp())) {
+    if (mlir::isa_and_nonnull<fir::AddrOfOp>(
+            declareOp.getMemref().getDefiningOp())) {
+      if (isPinned(declareOp))
+        return false;
+      return true;
+    }
+  } else if (auto declareOp = mlir::dyn_cast_or_null<hlfir::DeclareOp>(
+                 op.getBox().getDefiningOp())) {
+    if (mlir::isa_and_nonnull<fir::AddrOfOp>(
+            declareOp.getMemref().getDefiningOp())) {
+      if (isPinned(declareOp))
+        return false;
+      return true;
+    }
+  }
+  return false;
+}
+
+static bool inDeviceContext(mlir::Operation *op) {
+  if (op->getParentOfType<cuf::KernelOp>())
+    return true;
+  if (auto funcOp = op->getParentOfType<mlir::gpu::GPUFuncOp>())
+    return true;
+  if (auto funcOp = op->getParentOfType<mlir::gpu::LaunchOp>())
+    return true;
+  if (auto funcOp = op->getParentOfType<mlir::func::FuncOp>()) {
+    if (auto cudaProcAttr =
+            funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
+                cuf::getProcAttrName())) {
+      return cudaProcAttr.getValue() != cuf::ProcAttribute::Host &&
+             cudaProcAttr.getValue() != cuf::ProcAttribute::HostDevice;
+    }
+  }
+  return false;
+}
+
+template <typename OpTy>
+static mlir::LogicalResult convertOpToCall(OpTy op,
+                                           mlir::PatternRewriter &rewriter,
+                                           mlir::func::FuncOp func) {
+  auto mod = op->template getParentOfType<mlir::ModuleOp>();
+  fir::FirOpBuilder builder(rewriter, mod);
+  mlir::Location loc = op.getLoc();
+  auto fTy = func.getFunctionType();
+
+  mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+  mlir::Value sourceLine;
+  if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>)
+    sourceLine = fir::factory::locationToLineNo(
+        builder, loc, op.getSource() ? fTy.getInput(7) : fTy.getInput(6));
+  else
+    sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
+
+  mlir::Value hasStat = op.getHasStat() ? builder.createBool(loc, true)
+                                        : builder.createBool(loc, false);
+
+  mlir::Value errmsg;
+  if (op.getErrmsg()) {
+    errmsg = op.getErrmsg();
+  } else {
+    mlir::Type boxNoneTy = fir::BoxType::get(builder.getNoneType());
+    errmsg = fir::AbsentOp::create(builder, loc, boxNoneTy).getResult();
+  }
+  llvm::SmallVector<mlir::Value> args;
+  if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>) {
+    mlir::Value pinned =
+        op.getPinned()
+            ? op.getPinned()
+            : builder.createNullConstant(
+                  loc, fir::ReferenceType::get(
+                           mlir::IntegerType::get(op.getContext(), 1)));
+    if (op.getSource()) {
+      mlir::Value stream =
+          op.getStream() ? op.getStream()
+                         : builder.createNullConstant(loc, fTy.getInput(2));
+      args = fir::runtime::createArguments(
+          builder, loc, fTy, op.getBox(), op.getSource(), stream, pinned,
+          hasStat, errmsg, sourceFile, sourceLine);
+    } else {
+      mlir::Value stream =
+          op.getStream() ? op.getStream()
+                         : builder.createNullConstant(loc, fTy.getInput(1));
+      args = fir::runtime::createArguments(builder, loc, fTy, op.getBox(),
+                                           stream, pinned, hasStat, errmsg,
+                                           sourceFile, sourceLine);
+    }
+  } else {
+    args =
+        fir::runtime::createArguments(builder, loc, fTy, op.getBox(), hasStat,
+                                      errmsg, sourceFile, sourceLine);
+  }
+  auto callOp = fir::CallOp::create(builder, loc, func, args);
+  rewriter.replaceOp(op, callOp);
+  return mlir::success();
+}
+
+struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  CUFAllocOpConversion(mlir::MLIRContext *context, mlir::DataLayout *dl,
+                       const fir::LLVMTypeConverter *typeConverter)
+      : OpRewritePattern(context), dl{dl}, typeConverter{typeConverter} {}
+
+  mlir::LogicalResult
+  matchAndRewrite(cuf::AllocOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+
+    mlir::Location loc = op.getLoc();
+
+    if (inDeviceContext(op.getOperation())) {
+      // In device context, just replace the cuf.alloc operation with a
+      // fir.alloca; the corresponding cuf.free will be removed.
+      auto allocaOp =
+          fir::AllocaOp::create(rewriter, loc, op.getInType(),
+                                op.getUniqName() ? *op.getUniqName() : "",
+                                op.getBindcName() ? *op.getBindcName() : "",
+                                op.getTypeparams(), op.getShape());
+      allocaOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
+      rewriter.replaceOp(op, allocaOp);
+      return mlir::success();
+    }
+
+    auto mod = op->getParentOfType<mlir::ModuleOp>();
+    fir::FirOpBuilder builder(rewriter, mod);
+    mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+
+    if (!mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType())) {
+      // Convert scalar and known size array allocations.
+      mlir::Value bytes;
+      fir::KindMapping kindMap{fir::getKindMapping(mod)};
+      if (fir::isa_trivial(op.getInType())) {
+        int width = cuf::computeElementByteSize(loc, op.getInType(), kindMap);
+        bytes =
+            builder.createIntegerConstant(loc, builder.getIndexType(), width);
+      } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
+                     op.getInType())) {
+        std::size_t size = 0;
+        if (fir::isa_derived(seqTy.getEleTy())) {
+          mlir::Type structTy = typeConverter->convertType(seqTy.getEleTy());
+          size = dl->getTypeSizeInBits(structTy) / 8;
+        } else {
+          size = cuf::computeElementByteSize(loc, seqTy.getEleTy(), kindMap);
+        }
+        mlir::Value width =
+            builder.createIntegerConstant(loc, builder.getIndexType(), size);
+        mlir::Value nbElem;
+        if (fir::sequenceWithNonConstantShape(seqTy)) {
+          assert(!op.getShape().empty() && "expect shape with dynamic arrays");
+          nbElem = builder.loadIfRef(loc, op.getShape()[0]);
+          for (unsigned i = 1; i < op.getShape().size(); ++i) {
+            nbElem = mlir::arith::MulIOp::create(
+                rewriter, loc, nbElem,
+                builder.loadIfRef(loc, op.getShape()[i]));
+          }
+        } else {
+          nbElem = builder.createIntegerConstant(loc, builder.getIndexType(),
+                                                 seqTy.getConstantArraySize());
+        }
+        bytes = mlir::arith::MulIOp::create(rewriter, loc, nbElem, width);
+      } else if (fir::isa_derived(op.getInType())) {
+        mlir::Type structTy = typeConverter->convertType(op.getInType());
+        std::size_t structSize = dl->getTypeSizeInBits(structTy) / 8;
+        bytes = builder.createIntegerConstant(loc, builder.getIndexType(),
+                                              structSize);
+      } else if (fir::isa_char(op.getInType())) {
+        mlir::Type charTy = typeConverter->convertType(op.getInType());
+        std::size_t charSize = dl->getTypeSizeInBits(charTy) / 8;
+        bytes = builder.createIntegerConstant(loc, builder.getIndexType(),
+                                              charSize);
+      } else {
+        mlir::emitError(loc, "unsupported type in cuf.alloc\n");
+      }
+      mlir::func::FuncOp func =
+          fir::runtime::getRuntimeFunc<mkRTKey(CUFMemAlloc)>(loc, builder);
+      auto fTy = func.getFunctionType();
+      mlir::Value sourceLine =
+          fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
+      mlir::Value memTy = builder.createIntegerConstant(
+          loc, builder.getI32Type(), getMemType(op.getDataAttr()));
+      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+          builder, loc, fTy, bytes, memTy, sourceFile, sourceLine)};
+      auto callOp = fir::CallOp::create(builder, loc, func, args);
+      callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
+      auto convOp = builder.createConvert(loc, op.getResult().getType(),
+                                          callOp.getResult(0));
+      rewriter.replaceOp(op, convOp);
+      return mlir::success();
+    }
+
+    // Convert descriptor allocations to function call.
+    auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
+    mlir::func::FuncOp func =
+        fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDescriptor)>(loc, builder);
+    auto fTy = func.getFunctionType();
+    mlir::Value sourceLine =
+        fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+
+    mlir::Type structTy = typeConverter->convertBoxTypeAsStruct(boxTy);
+    std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
+    mlir::Value sizeInBytes =
+        builder.createIntegerConstant(loc, builder.getIndexType(), boxSize);
+
+    llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+        builder, loc, fTy, sizeInBytes, sourceFile, sourceLine)};
+    auto callOp = fir::CallOp::create(builder, loc, func, args);
+    callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
+    auto convOp = builder.createConvert(loc, op.getResult().getType(),
+                                        callOp.getResult(0));
+    rewriter.replaceOp(op, convOp);
+    return mlir::success();
+  }
+
+private:
+  mlir::DataLayout *dl;
+  const fir::LLVMTypeConverter *typeConverter;
+};
+
+struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(cuf::FreeOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    if (inDeviceContext(op.getOperation())) {
+      rewriter.eraseOp(op);
+      return mlir::success();
+    }
+
+    if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
+      return failure();
+
+    auto mod = op->getParentOfType<mlir::ModuleOp>();
+    fir::FirOpBuilder builder(rewriter, mod);
+    mlir::Location loc = op.getLoc();
+    mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+
+    auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
+    if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy())) {
+      mlir::func::FuncOp func =
+          fir::runtime::getRuntimeFunc<mkRTKey(CUFMemFree)>(loc, builder);
+      auto fTy = func.getFunctionType();
+      mlir::Value sourceLine =
+          fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
+      mlir::Value memTy = builder.createIntegerConstant(
+          loc, builder.getI32Type(), getMemType(op.getDataAttr()));
+      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+          builder, loc, fTy, op.getDevptr(), memTy, sourceFile, sourceLine)};
+      fir::CallOp::create(builder, loc, func, args);
+      rewriter.eraseOp(op);
+      return mlir::success();
+    }
+
+    // Convert cuf.free on descriptors.
+    mlir::func::FuncOp func =
+        fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDescriptor)>(loc, builder);
+    auto fTy = func.getFunctionType();
+    mlir::Value sourceLine =
+        fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+    llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+        builder, loc, fTy, op.getDevptr(), sourceFile, sourceLine)};
+    auto callOp = fir::CallOp::create(builder, loc, func, args);
+    callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
+    rewriter.eraseOp(op);
+    return mlir::success();
+  }
+};
+
+struct CUFAllocateOpConversion
+    : public mlir::OpRewritePattern<cuf::AllocateOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(cuf::AllocateOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto mod = op->getParentOfType<mlir::ModuleOp>();
+    fir::FirOpBuilder builder(rewriter, mod);
+    mlir::Location loc = op.getLoc();
+
+    bool isPointer = false;
+
+    if (auto declareOp =
+            mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp()))
+      if (declareOp.getFortranAttrs() &&
+          bitEnumContainsAny(*declareOp.getFortranAttrs(),
+                             fir::FortranVariableFlagsEnum::pointer))
+        isPointer = true;
+
+    if (hasDoubleDescriptors(op)) {
+      // Allocations for module variables are done with a custom runtime entry
+      // point so the descriptors can be synchronized.
+      mlir::func::FuncOp func;
+      if (op.getSource()) {
+        func = isPointer ? fir::runtime::getRuntimeFunc<mkRTKey(
+                               CUFPointerAllocateSourceSync)>(loc, builder)
+                         : fir::runtime::getRuntimeFunc<mkRTKey(
+                               CUFAllocatableAllocateSourceSync)>(loc, builder);
+      } else {
+        func =
+            isPointer
+                ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSync)>(
+                      loc, builder)
+                : fir::runtime::getRuntimeFunc<mkRTKey(
+                      CUFAllocatableAllocateSync)>(loc, builder);
+      }
+      return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
+    }
+
+    mlir::func::FuncOp func;
+    if (op.getSource()) {
+      func =
+          isPointer
+              ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSource)>(
+                    loc, builder)
+              : fir::runtime::getRuntimeFunc<mkRTKey(
+                    CUFAllocatableAllocateSource)>(loc, builder);
+    } else {
+      func =
+          isPointer
+              ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocate)>(
+                    loc, builder)
+              : fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocate)>(
+                    loc, builder);
+    }
+
+    return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
+  }
+};
+
+struct CUFDeallocateOpConversion
+    : public mlir::OpRewritePattern<cuf::DeallocateOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(cuf::DeallocateOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+
+    auto mod = op->getParentOfType<mlir::ModuleOp>();
+    fir::FirOpBuilder builder(rewriter, mod);
+    mlir::Location loc = op.getLoc();
+
+    if (hasDoubleDescriptors(op)) {
+      // Deallocations for module variables are done with a custom runtime
+      // entry point so the descriptors can be synchronized.
+      mlir::func::FuncOp func =
+          fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableDeallocate)>(
+              loc, builder);
+      return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
+    }
+
+    // Deallocation of a local descriptor falls back on the standard runtime
+    // AllocatableDeallocate, as the dedicated deallocator is set in the
+    // descriptor before the call.
+    mlir::func::FuncOp func =
+        fir::runtime::getRuntimeFunc<mkRTKey(AllocatableDeallocate)>(loc,
+                                                                     builder);
+    return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
+  }
+};
+
+class CUFAllocationConversion
+    : public fir::impl::CUFAllocationConversionBase<CUFAllocationConversion> {
+public:
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    mlir::RewritePatternSet patterns(ctx);
+    mlir::ConversionTarget target(*ctx);
+
+    mlir::Operation *op = getOperation();
+    mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op);
+    if (!module)
+      return signalPassFailure();
+    mlir::SymbolTable symtab(module);
+
+    std::optional<mlir::DataLayout> dl = fir::support::getOrSetMLIRDataLayout(
+        module, /*allowDefaultLayout=*/false);
+    fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
+                                         /*forceUnifiedTBAATree=*/false, *dl);
+    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
+                           mlir::gpu::GPUDialect>();
+    target.addLegalOp<cuf::StreamCastOp>();
+    cuf::populateCUFAllocationConversionPatterns(typeConverter, *dl, symtab,
+                                                 patterns);
+    if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
+                                                  std::move(patterns)))) {
+      mlir::emitError(mlir::UnknownLoc::get(ctx),
+                      "error in CUF allocation conversion\n");
+      signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+
+void cuf::populateCUFAllocationConversionPatterns(
+    const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
+    const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
+  patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter);
+  patterns.insert<CUFFreeOpConversion, CUFAllocateOpConversion,
+                  CUFDeallocateOpConversion>(patterns.getContext());
+}

diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 7ed34f865d0e9..f2ab99a8bc8ee 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -16,6 +16,8 @@
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/Support/DataLayout.h"
+#include "flang/Optimizer/Transforms/CUDA/CUFAllocationConversion.h"
+#include "flang/Optimizer/Transforms/Passes.h"
 #include "flang/Runtime/CUDA/allocatable.h"
 #include "flang/Runtime/CUDA/common.h"
 #include "flang/Runtime/CUDA/descriptor.h"
@@ -44,207 +46,6 @@ using namespace Fortran::runtime::cuda;
 
 namespace {
 
-static inline unsigned getMemType(cuf::DataAttribute attr) {
-  if (attr == cuf::DataAttribute::Device)
-    return kMemTypeDevice;
-  if (attr == cuf::DataAttribute::Managed)
-    return kMemTypeManaged;
-  if (attr == cuf::DataAttribute::Unified)
-    return kMemTypeUnified;
-  if (attr == cuf::DataAttribute::Pinned)
-    return kMemTypePinned;
-  llvm::report_fatal_error("unsupported memory type");
-}
-
-template <typename OpTy>
-static bool isPinned(OpTy op) {
-  if (op.getDataAttr() && *op.getDataAttr() == cuf::DataAttribute::Pinned)
-    return true;
-  return false;
-}
-
-template <typename OpTy>
-static bool hasDoubleDescriptors(OpTy op) {
-  if (auto declareOp =
-          mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp())) {
-    if (mlir::isa_and_nonnull<fir::AddrOfOp>(
-            declareOp.getMemref().getDefiningOp())) {
-      if (isPinned(declareOp))
-        return false;
-      return true;
-    }
-  } else if (auto declareOp = mlir::dyn_cast_or_null<hlfir::DeclareOp>(
-                 op.getBox().getDefiningOp())) {
-    if (mlir::isa_and_nonnull<fir::AddrOfOp>(
-            declareOp.getMemref().getDefiningOp())) {
-      if (isPinned(declareOp))
-        return false;
-      return true;
-    }
-  }
-  return false;
-}
-
-static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
-                                   mlir::Location loc, mlir::Type toTy,
-                                   mlir::Value val) {
-  if (val.getType() != toTy)
-    return fir::ConvertOp::create(rewriter, loc, toTy, val);
-  return val;
-}
-
-template <typename OpTy>
-static mlir::LogicalResult convertOpToCall(OpTy op,
-                                           mlir::PatternRewriter &rewriter,
-                                           mlir::func::FuncOp func) {
-  auto mod = op->template getParentOfType<mlir::ModuleOp>();
-  fir::FirOpBuilder builder(rewriter, mod);
-  mlir::Location loc = op.getLoc();
-  auto fTy = func.getFunctionType();
-
-  mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
-  mlir::Value sourceLine;
-  if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>)
-    sourceLine = fir::factory::locationToLineNo(
-        builder, loc, op.getSource() ? fTy.getInput(7) : fTy.getInput(6));
-  else
-    sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
-
-  mlir::Value hasStat = op.getHasStat() ? builder.createBool(loc, true)
-                                        : builder.createBool(loc, false);
-
-  mlir::Value errmsg;
-  if (op.getErrmsg()) {
-    errmsg = op.getErrmsg();
-  } else {
-    mlir::Type boxNoneTy = fir::BoxType::get(builder.getNoneType());
-    errmsg = fir::AbsentOp::create(builder, loc, boxNoneTy).getResult();
-  }
-  llvm::SmallVector<mlir::Value> args;
-  if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>) {
-    mlir::Value pinned =
-        op.getPinned()
-            ? op.getPinned()
-            : builder.createNullConstant(
-                  loc, fir::ReferenceType::get(
-                           mlir::IntegerType::get(op.getContext(), 1)));
-    if (op.getSource()) {
-      mlir::Value stream =
-          op.getStream() ? op.getStream()
-                         : builder.createNullConstant(loc, fTy.getInput(2));
-      args = fir::runtime::createArguments(
-          builder, loc, fTy, op.getBox(), op.getSource(), stream, pinned,
-          hasStat, errmsg, sourceFile, sourceLine);
-    } else {
-      mlir::Value stream =
-          op.getStream() ? op.getStream()
-                         : builder.createNullConstant(loc, fTy.getInput(1));
-      args = fir::runtime::createArguments(builder, loc, fTy, op.getBox(),
-                                           stream, pinned, hasStat, errmsg,
-                                           sourceFile, sourceLine);
-    }
-  } else {
-    args =
-        fir::runtime::createArguments(builder, loc, fTy, op.getBox(), hasStat,
-                                      errmsg, sourceFile, sourceLine);
-  }
-  auto callOp = fir::CallOp::create(builder, loc, func, args);
-  rewriter.replaceOp(op, callOp);
-  return mlir::success();
-}
-
-struct CUFAllocateOpConversion
-    : public mlir::OpRewritePattern<cuf::AllocateOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  mlir::LogicalResult
-  matchAndRewrite(cuf::AllocateOp op,
-                  mlir::PatternRewriter &rewriter) const override {
-    auto mod = op->getParentOfType<mlir::ModuleOp>();
-    fir::FirOpBuilder builder(rewriter, mod);
-    mlir::Location loc = op.getLoc();
-
-    bool isPointer = false;
-
-    if (auto declareOp =
-            mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp()))
-      if (declareOp.getFortranAttrs() &&
-          bitEnumContainsAny(*declareOp.getFortranAttrs(),
-                             fir::FortranVariableFlagsEnum::pointer))
-        isPointer = true;
-
-    if (hasDoubleDescriptors(op)) {
-      // Allocation for module variable are done with custom runtime entry point
-      // so the descriptors can be synchronized.
-      mlir::func::FuncOp func;
-      if (op.getSource()) {
-        func = isPointer ? fir::runtime::getRuntimeFunc<mkRTKey(
-                               CUFPointerAllocateSourceSync)>(loc, builder)
-                         : fir::runtime::getRuntimeFunc<mkRTKey(
-                               CUFAllocatableAllocateSourceSync)>(loc, builder);
-      } else {
-        func =
-            isPointer
-                ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSync)>(
-                      loc, builder)
-                : fir::runtime::getRuntimeFunc<mkRTKey(
-                      CUFAllocatableAllocateSync)>(loc, builder);
-      }
-      return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
-    }
-
-    mlir::func::FuncOp func;
-    if (op.getSource()) {
-      func =
-          isPointer
-              ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSource)>(
-                    loc, builder)
-              : fir::runtime::getRuntimeFunc<mkRTKey(
-                    CUFAllocatableAllocateSource)>(loc, builder);
-    } else {
-      func =
-          isPointer
-              ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocate)>(
-                    loc, builder)
-              : fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocate)>(
-                    loc, builder);
-    }
-
-    return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
-  }
-};
-
-struct CUFDeallocateOpConversion
-    : public mlir::OpRewritePattern<cuf::DeallocateOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  mlir::LogicalResult
-  matchAndRewrite(cuf::DeallocateOp op,
-                  mlir::PatternRewriter &rewriter) const override {
-
-    auto mod = op->getParentOfType<mlir::ModuleOp>();
-    fir::FirOpBuilder builder(rewriter, mod);
-    mlir::Location loc = op.getLoc();
-
-    if (hasDoubleDescriptors(op)) {
-      // Deallocation for module variable are done with custom runtime entry
-      // point so the descriptors can be synchronized.
-      mlir::func::FuncOp func =
-          fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableDeallocate)>(
-              loc, builder);
-      return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
-    }
-
-    // Deallocation for local descriptor falls back on the standard runtime
-    // AllocatableDeallocate as the dedicated deallocator is set in the
-    // descriptor before the call.
-    mlir::func::FuncOp func =
-        fir::runtime::getRuntimeFunc<mkRTKey(AllocatableDeallocate)>(loc,
-                                                                     builder);
-    return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
-  }
-};
-
 static bool inDeviceContext(mlir::Operation *op) {
   if (op->getParentOfType<cuf::KernelOp>())
     return true;
@@ -263,126 +64,13 @@ static bool inDeviceContext(mlir::Operation *op) {
   return false;
 }
 
-struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  CUFAllocOpConversion(mlir::MLIRContext *context, mlir::DataLayout *dl,
-                       const fir::LLVMTypeConverter *typeConverter)
-      : OpRewritePattern(context), dl{dl}, typeConverter{typeConverter} {}
-
-  mlir::LogicalResult
-  matchAndRewrite(cuf::AllocOp op,
-                  mlir::PatternRewriter &rewriter) const override {
-
-    mlir::Location loc = op.getLoc();
-
-    if (inDeviceContext(op.getOperation())) {
-      // In device context just replace the cuf.alloc operation with a fir.alloc
-      // the cuf.free will be removed.
-      auto allocaOp =
-          fir::AllocaOp::create(rewriter, loc, op.getInType(),
-                                op.getUniqName() ? *op.getUniqName() : "",
-                                op.getBindcName() ? *op.getBindcName() : "",
-                                op.getTypeparams(), op.getShape());
-      allocaOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
-      rewriter.replaceOp(op, allocaOp);
-      return mlir::success();
-    }
-
-    auto mod = op->getParentOfType<mlir::ModuleOp>();
-    fir::FirOpBuilder builder(rewriter, mod);
-    mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
-
-    if (!mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType())) {
-      // Convert scalar and known size array allocations.
-      mlir::Value bytes;
-      fir::KindMapping kindMap{fir::getKindMapping(mod)};
-      if (fir::isa_trivial(op.getInType())) {
-        int width = cuf::computeElementByteSize(loc, op.getInType(), kindMap);
-        bytes =
-            builder.createIntegerConstant(loc, builder.getIndexType(), width);
-      } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
-                     op.getInType())) {
-        std::size_t size = 0;
-        if (fir::isa_derived(seqTy.getEleTy())) {
-          mlir::Type structTy = typeConverter->convertType(seqTy.getEleTy());
-          size = dl->getTypeSizeInBits(structTy) / 8;
-        } else {
-          size = cuf::computeElementByteSize(loc, seqTy.getEleTy(), kindMap);
-        }
-        mlir::Value width =
-            builder.createIntegerConstant(loc, builder.getIndexType(), size);
-        mlir::Value nbElem;
-        if (fir::sequenceWithNonConstantShape(seqTy)) {
-          assert(!op.getShape().empty() && "expect shape with dynamic arrays");
-          nbElem = builder.loadIfRef(loc, op.getShape()[0]);
-          for (unsigned i = 1; i < op.getShape().size(); ++i) {
-            nbElem = mlir::arith::MulIOp::create(
-                rewriter, loc, nbElem,
-                builder.loadIfRef(loc, op.getShape()[i]));
-          }
-        } else {
-          nbElem = builder.createIntegerConstant(loc, builder.getIndexType(),
-                                                 seqTy.getConstantArraySize());
-        }
-        bytes = mlir::arith::MulIOp::create(rewriter, loc, nbElem, width);
-      } else if (fir::isa_derived(op.getInType())) {
-        mlir::Type structTy = typeConverter->convertType(op.getInType());
-        std::size_t structSize = dl->getTypeSizeInBits(structTy) / 8;
-        bytes = builder.createIntegerConstant(loc, builder.getIndexType(),
-                                              structSize);
-      } else if (fir::isa_char(op.getInType())) {
-        mlir::Type charTy = typeConverter->convertType(op.getInType());
-        std::size_t charSize = dl->getTypeSizeInBits(charTy) / 8;
-        bytes = builder.createIntegerConstant(loc, builder.getIndexType(),
-                                              charSize);
-      } else {
-        mlir::emitError(loc, "unsupported type in cuf.alloc\n");
-      }
-      mlir::func::FuncOp func =
-          fir::runtime::getRuntimeFunc<mkRTKey(CUFMemAlloc)>(loc, builder);
-      auto fTy = func.getFunctionType();
-      mlir::Value sourceLine =
-          fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
-      mlir::Value memTy = builder.createIntegerConstant(
-          loc, builder.getI32Type(), getMemType(op.getDataAttr()));
-      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
-          builder, loc, fTy, bytes, memTy, sourceFile, sourceLine)};
-      auto callOp = fir::CallOp::create(builder, loc, func, args);
-      callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
-      auto convOp = builder.createConvert(loc, op.getResult().getType(),
-                                          callOp.getResult(0));
-      rewriter.replaceOp(op, convOp);
-      return mlir::success();
-    }
-
-    // Convert descriptor allocations to function call.
-    auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
-    mlir::func::FuncOp func =
-        fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDescriptor)>(loc, builder);
-    auto fTy = func.getFunctionType();
-    mlir::Value sourceLine =
-        fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
-
-    mlir::Type structTy = typeConverter->convertBoxTypeAsStruct(boxTy);
-    std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
-    mlir::Value sizeInBytes =
-        builder.createIntegerConstant(loc, builder.getIndexType(), boxSize);
-
-    llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
-        builder, loc, fTy, sizeInBytes, sourceFile, sourceLine)};
-    auto callOp = fir::CallOp::create(builder, loc, func, args);
-    callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
-    auto convOp = builder.createConvert(loc, op.getResult().getType(),
-                                        callOp.getResult(0));
-    rewriter.replaceOp(op, convOp);
-    return mlir::success();
-  }
-
-private:
-  mlir::DataLayout *dl;
-  const fir::LLVMTypeConverter *typeConverter;
-};
+static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
+                                   mlir::Location loc, mlir::Type toTy,
+                                   mlir::Value val) {
+  if (val.getType() != toTy)
+    return fir::ConvertOp::create(rewriter, loc, toTy, val);
+  return val;
+}
 
 struct CUFDeviceAddressOpConversion
     : public mlir::OpRewritePattern<cuf::DeviceAddressOp> {
@@ -460,56 +148,6 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
   const mlir::SymbolTable &symTab;
 };
 
-struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  mlir::LogicalResult
-  matchAndRewrite(cuf::FreeOp op,
-                  mlir::PatternRewriter &rewriter) const override {
-    if (inDeviceContext(op.getOperation())) {
-      rewriter.eraseOp(op);
-      return mlir::success();
-    }
-
-    if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
-      return failure();
-
-    auto mod = op->getParentOfType<mlir::ModuleOp>();
-    fir::FirOpBuilder builder(rewriter, mod);
-    mlir::Location loc = op.getLoc();
-    mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
-
-    auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
-    if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy())) {
-      mlir::func::FuncOp func =
-          fir::runtime::getRuntimeFunc<mkRTKey(CUFMemFree)>(loc, builder);
-      auto fTy = func.getFunctionType();
-      mlir::Value sourceLine =
-          fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
-      mlir::Value memTy = builder.createIntegerConstant(
-          loc, builder.getI32Type(), getMemType(op.getDataAttr()));
-      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
-          builder, loc, fTy, op.getDevptr(), memTy, sourceFile, sourceLine)};
-      fir::CallOp::create(builder, loc, func, args);
-      rewriter.eraseOp(op);
-      return mlir::success();
-    }
-
-    // Convert cuf.free on descriptors.
-    mlir::func::FuncOp func =
-        fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDescriptor)>(loc, builder);
-    auto fTy = func.getFunctionType();
-    mlir::Value sourceLine =
-        fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
-    llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
-        builder, loc, fTy, op.getDevptr(), sourceFile, sourceLine)};
-    auto callOp = fir::CallOp::create(builder, loc, func, args);
-    callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
-    rewriter.eraseOp(op);
-    return mlir::success();
-  }
-};
-
 static bool isDstGlobal(cuf::DataTransferOp op) {
   if (auto declareOp = op.getDst().getDefiningOp<fir::DeclareOp>())
     if (declareOp.getMemref().getDefiningOp<fir::AddrOfOp>())
@@ -896,6 +534,8 @@ struct CUFSyncDescriptorOpConversion
 };
 
 class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
+  using CUFOpConversionBase::CUFOpConversionBase;
+
 public:
   void runOnOperation() override {
     auto *ctx = &getContext();
@@ -917,6 +557,8 @@ class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
     target.addLegalOp<cuf::StreamCastOp>();
     cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, symtab,
                                             patterns);
+    cuf::populateCUFAllocationConversionPatterns(typeConverter, *dl, symtab,
+                                                 patterns);
     if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
                                                   std::move(patterns)))) {
       mlir::emitError(mlir::UnknownLoc::get(ctx),
@@ -956,10 +598,7 @@ class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
 void cuf::populateCUFToFIRConversionPatterns(
     const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
     const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
-  patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter);
-  patterns.insert<CUFAllocateOpConversion, CUFDeallocateOpConversion,
-                  CUFFreeOpConversion, CUFSyncDescriptorOpConversion>(
-      patterns.getContext());
+  patterns.insert<CUFSyncDescriptorOpConversion>(patterns.getContext());
   patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
                                                &dl, &converter);
   patterns.insert<CUFLaunchOpConversion, CUFDeviceAddressOpConversion>(


        

