[flang-commits] [flang] [flang][cuda] Support malloc and free conversion in gpu module (PR #116112)

via flang-commits flang-commits at lists.llvm.org
Wed Nov 13 14:00:07 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

Update `getMalloc` and `getFree` to work with the enclosing module (ModuleOp or GPUModuleOp) so we can convert `fir.allocmem` and `fir.freemem` in device code. 

---
Full diff: https://github.com/llvm/llvm-project/pull/116112.diff


3 Files Affected:

- (modified) flang/lib/Optimizer/CodeGen/CMakeLists.txt (+1) 
- (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+36-13) 
- (modified) flang/test/Fir/convert-to-llvm.fir (+16) 


``````````diff
diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 646621cb01c157..f47d11875f04db 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -23,6 +23,7 @@ add_flang_library(FIRCodeGen
   FIRSupport
   MLIRComplexToLLVM
   MLIRComplexToStandard
+  MLIRGPUDialect
   MLIRMathToFuncs
   MLIRMathToLLVM
   MLIRMathToLibm
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index d038efcb2eb42c..3452a662f7a194 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -41,6 +41,7 @@
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
@@ -920,17 +921,19 @@ struct EmboxCharOpConversion : public fir::FIROpConversion<fir::EmboxCharOp> {
 };
 } // namespace
 
-/// Return the LLVMFuncOp corresponding to the standard malloc call.
+template <typename ModuleOp>
 static mlir::SymbolRefAttr
-getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
+getMallocInModule(ModuleOp mod, fir::AllocMemOp op,
+                  mlir::ConversionPatternRewriter &rewriter) {
   static constexpr char mallocName[] = "malloc";
-  auto module = op->getParentOfType<mlir::ModuleOp>();
-  if (auto mallocFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(mallocName))
+  if (auto mallocFunc =
+          mod.template lookupSymbol<mlir::LLVM::LLVMFuncOp>(mallocName))
     return mlir::SymbolRefAttr::get(mallocFunc);
-  if (auto userMalloc = module.lookupSymbol<mlir::func::FuncOp>(mallocName))
+  if (auto userMalloc =
+          mod.template lookupSymbol<mlir::func::FuncOp>(mallocName))
     return mlir::SymbolRefAttr::get(userMalloc);
-  mlir::OpBuilder moduleBuilder(
-      op->getParentOfType<mlir::ModuleOp>().getBodyRegion());
+
+  mlir::OpBuilder moduleBuilder(mod.getBodyRegion());
   auto indexType = mlir::IntegerType::get(op.getContext(), 64);
   auto mallocDecl = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
       op.getLoc(), mallocName,
@@ -940,6 +943,15 @@ getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
   return mlir::SymbolRefAttr::get(mallocDecl);
 }
 
+/// Return the LLVMFuncOp corresponding to the standard malloc call.
+static mlir::SymbolRefAttr
+getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
+  if (auto mod = op->getParentOfType<mlir::gpu::GPUModuleOp>())
+    return getMallocInModule(mod, op, rewriter);
+  auto mod = op->getParentOfType<mlir::ModuleOp>();
+  return getMallocInModule(mod, op, rewriter);
+}
+
 /// Helper function for generating the LLVM IR that computes the distance
 /// in bytes between adjacent elements pointed to by a pointer
 /// of type \p ptrTy. The result is returned as a value of \p idxTy integer
@@ -1016,18 +1028,20 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
 } // namespace
 
 /// Return the LLVMFuncOp corresponding to the standard free call.
-static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
-                                   mlir::ConversionPatternRewriter &rewriter) {
+template <typename ModuleOp>
+static mlir::SymbolRefAttr
+getFreeInModule(ModuleOp mod, fir::FreeMemOp op,
+                mlir::ConversionPatternRewriter &rewriter) {
   static constexpr char freeName[] = "free";
-  auto module = op->getParentOfType<mlir::ModuleOp>();
   // Check if free already defined in the module.
-  if (auto freeFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(freeName))
+  if (auto freeFunc =
+          mod.template lookupSymbol<mlir::LLVM::LLVMFuncOp>(freeName))
     return mlir::SymbolRefAttr::get(freeFunc);
   if (auto freeDefinedByUser =
-          module.lookupSymbol<mlir::func::FuncOp>(freeName))
+          mod.template lookupSymbol<mlir::func::FuncOp>(freeName))
     return mlir::SymbolRefAttr::get(freeDefinedByUser);
   // Create llvm declaration for free.
-  mlir::OpBuilder moduleBuilder(module.getBodyRegion());
+  mlir::OpBuilder moduleBuilder(mod.getBodyRegion());
   auto voidType = mlir::LLVM::LLVMVoidType::get(op.getContext());
   auto freeDecl = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
       rewriter.getUnknownLoc(), freeName,
@@ -1037,6 +1051,14 @@ static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
   return mlir::SymbolRefAttr::get(freeDecl);
 }
 
+static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
+                                   mlir::ConversionPatternRewriter &rewriter) {
+  if (auto mod = op->getParentOfType<mlir::gpu::GPUModuleOp>())
+    return getFreeInModule(mod, op, rewriter);
+  auto mod = op->getParentOfType<mlir::ModuleOp>();
+  return getFreeInModule(mod, op, rewriter);
+}
+
 static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) {
   unsigned result = 1;
   for (auto eleTy =
@@ -3730,6 +3752,7 @@ class FIRToLLVMLowering
     mlir::configureOpenMPToLLVMConversionLegality(target, typeConverter);
     target.addLegalDialect<mlir::omp::OpenMPDialect>();
     target.addLegalDialect<mlir::acc::OpenACCDialect>();
+    target.addLegalDialect<mlir::gpu::GPUDialect>();
 
     // required NOPs for applying a full conversion
     target.addLegalOp<mlir::ModuleOp>();
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index fa391fa6cc7a7d..4c9f965e1241a0 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2776,3 +2776,19 @@ func.func @coordinate_array_unknown_size_1d(%arg0: !fir.ptr<!fir.array<? x i32>>
 fir.global common @c_(dense<0> : vector<4294967296xi8>) : !fir.array<4294967296xi8>
 
 // CHECK: llvm.mlir.global common @c_(dense<0> : vector<4294967296xi8>) {addr_space = 0 : i32} : !llvm.array<4294967296 x i8>
+
+// -----
+
+gpu.module @cuda_device_mod {
+  gpu.func @test_alloc_and_freemem_one() {
+    %z0 = fir.allocmem i32
+    fir.freemem %z0 : !fir.heap<i32>
+    gpu.return
+  }
+}
+
+// CHECK: gpu.module @cuda_device_mod {
+// CHECK: llvm.func @free(!llvm.ptr)
+// CHECK: llvm.func @malloc(i64) -> !llvm.ptr
+// CHECK: llvm.call @malloc
+// CHECK: llvm.call @free

``````````

</details>


https://github.com/llvm/llvm-project/pull/116112


More information about the flang-commits mailing list