[flang-commits] [flang] [flang][cuda] Support malloc and free conversion in gpu module (PR #116112)
via flang-commits
flang-commits at lists.llvm.org
Wed Nov 13 14:00:07 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
Update `getMalloc` and `getFree` to work with the enclosing module (ModuleOp or GPUModuleOp) so we can convert `fir.allocmem` and `fir.freemem` in device code.
---
Full diff: https://github.com/llvm/llvm-project/pull/116112.diff
3 Files Affected:
- (modified) flang/lib/Optimizer/CodeGen/CMakeLists.txt (+1)
- (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+36-13)
- (modified) flang/test/Fir/convert-to-llvm.fir (+16)
``````````diff
diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 646621cb01c157..f47d11875f04db 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -23,6 +23,7 @@ add_flang_library(FIRCodeGen
FIRSupport
MLIRComplexToLLVM
MLIRComplexToStandard
+ MLIRGPUDialect
MLIRMathToFuncs
MLIRMathToLLVM
MLIRMathToLibm
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index d038efcb2eb42c..3452a662f7a194 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -41,6 +41,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
@@ -920,17 +921,19 @@ struct EmboxCharOpConversion : public fir::FIROpConversion<fir::EmboxCharOp> {
};
} // namespace
-/// Return the LLVMFuncOp corresponding to the standard malloc call.
+template <typename ModuleOp>
static mlir::SymbolRefAttr
-getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
+getMallocInModule(ModuleOp mod, fir::AllocMemOp op,
+ mlir::ConversionPatternRewriter &rewriter) {
static constexpr char mallocName[] = "malloc";
- auto module = op->getParentOfType<mlir::ModuleOp>();
- if (auto mallocFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(mallocName))
+ if (auto mallocFunc =
+ mod.template lookupSymbol<mlir::LLVM::LLVMFuncOp>(mallocName))
return mlir::SymbolRefAttr::get(mallocFunc);
- if (auto userMalloc = module.lookupSymbol<mlir::func::FuncOp>(mallocName))
+ if (auto userMalloc =
+ mod.template lookupSymbol<mlir::func::FuncOp>(mallocName))
return mlir::SymbolRefAttr::get(userMalloc);
- mlir::OpBuilder moduleBuilder(
- op->getParentOfType<mlir::ModuleOp>().getBodyRegion());
+
+ mlir::OpBuilder moduleBuilder(mod.getBodyRegion());
auto indexType = mlir::IntegerType::get(op.getContext(), 64);
auto mallocDecl = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
op.getLoc(), mallocName,
@@ -940,6 +943,15 @@ getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
return mlir::SymbolRefAttr::get(mallocDecl);
}
+/// Return the LLVMFuncOp corresponding to the standard malloc call.
+static mlir::SymbolRefAttr
+getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
+ if (auto mod = op->getParentOfType<mlir::gpu::GPUModuleOp>())
+ return getMallocInModule(mod, op, rewriter);
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ return getMallocInModule(mod, op, rewriter);
+}
+
/// Helper function for generating the LLVM IR that computes the distance
/// in bytes between adjacent elements pointed to by a pointer
/// of type \p ptrTy. The result is returned as a value of \p idxTy integer
@@ -1016,18 +1028,20 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
} // namespace
/// Return the LLVMFuncOp corresponding to the standard free call.
-static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
- mlir::ConversionPatternRewriter &rewriter) {
+template <typename ModuleOp>
+static mlir::SymbolRefAttr
+getFreeInModule(ModuleOp mod, fir::FreeMemOp op,
+ mlir::ConversionPatternRewriter &rewriter) {
static constexpr char freeName[] = "free";
- auto module = op->getParentOfType<mlir::ModuleOp>();
// Check if free already defined in the module.
- if (auto freeFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(freeName))
+ if (auto freeFunc =
+ mod.template lookupSymbol<mlir::LLVM::LLVMFuncOp>(freeName))
return mlir::SymbolRefAttr::get(freeFunc);
if (auto freeDefinedByUser =
- module.lookupSymbol<mlir::func::FuncOp>(freeName))
+ mod.template lookupSymbol<mlir::func::FuncOp>(freeName))
return mlir::SymbolRefAttr::get(freeDefinedByUser);
// Create llvm declaration for free.
- mlir::OpBuilder moduleBuilder(module.getBodyRegion());
+ mlir::OpBuilder moduleBuilder(mod.getBodyRegion());
auto voidType = mlir::LLVM::LLVMVoidType::get(op.getContext());
auto freeDecl = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
rewriter.getUnknownLoc(), freeName,
@@ -1037,6 +1051,14 @@ static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
return mlir::SymbolRefAttr::get(freeDecl);
}
+static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
+ mlir::ConversionPatternRewriter &rewriter) {
+ if (auto mod = op->getParentOfType<mlir::gpu::GPUModuleOp>())
+ return getFreeInModule(mod, op, rewriter);
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ return getFreeInModule(mod, op, rewriter);
+}
+
static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) {
unsigned result = 1;
for (auto eleTy =
@@ -3730,6 +3752,7 @@ class FIRToLLVMLowering
mlir::configureOpenMPToLLVMConversionLegality(target, typeConverter);
target.addLegalDialect<mlir::omp::OpenMPDialect>();
target.addLegalDialect<mlir::acc::OpenACCDialect>();
+ target.addLegalDialect<mlir::gpu::GPUDialect>();
// required NOPs for applying a full conversion
target.addLegalOp<mlir::ModuleOp>();
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index fa391fa6cc7a7d..4c9f965e1241a0 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2776,3 +2776,19 @@ func.func @coordinate_array_unknown_size_1d(%arg0: !fir.ptr<!fir.array<? x i32>>
fir.global common @c_(dense<0> : vector<4294967296xi8>) : !fir.array<4294967296xi8>
// CHECK: llvm.mlir.global common @c_(dense<0> : vector<4294967296xi8>) {addr_space = 0 : i32} : !llvm.array<4294967296 x i8>
+
+// -----
+
+gpu.module @cuda_device_mod {
+ gpu.func @test_alloc_and_freemem_one() {
+ %z0 = fir.allocmem i32
+ fir.freemem %z0 : !fir.heap<i32>
+ gpu.return
+ }
+}
+
+// CHECK: gpu.module @cuda_device_mod {
+// CHECK: llvm.func @free(!llvm.ptr)
+// CHECK: llvm.func @malloc(i64) -> !llvm.ptr
+// CHECK: llvm.call @malloc
+// CHECK: llvm.call @free
``````````
</details>
https://github.com/llvm/llvm-project/pull/116112
More information about the flang-commits
mailing list