[Mlir-commits] [mlir] a8601f1 - [MLIR] Generic 'malloc', 'aligned_alloc' and 'free' functions

Alex Zinenko llvmlistbot at llvm.org
Mon Jul 25 06:53:00 PDT 2022


Author: Michele Scuttari
Date: 2022-07-25T15:52:51+02:00
New Revision: a8601f11fbb753e552197c5aa835fd3c30c29fd3

URL: https://github.com/llvm/llvm-project/commit/a8601f11fbb753e552197c5aa835fd3c30c29fd3
DIFF: https://github.com/llvm/llvm-project/commit/a8601f11fbb753e552197c5aa835fd3c30c29fd3.diff

LOG: [MLIR] Generic 'malloc', 'aligned_alloc' and 'free' functions

When converted to the LLVM dialect, the memref.alloc and memref.free operations used to generate calls to the hardcoded 'malloc' and 'free' functions, which left users no way to plug in a custom implementation. These operations can now be converted into calls to '_mlir_alloc' and '_mlir_free' functions instead, which have also been added to the runtime support library as wrappers around 'malloc' and 'free'. The same has been done for the 'aligned_alloc' function.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D128791
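
As context for the change, here is a minimal sketch of what the runtime-support-library wrappers described above could look like; the actual implementation in the library may differ, and a 64-bit index type is assumed for the size arguments.

// Hedged sketch of the generic allocation entry points as thin wrappers
// around the C standard library. C linkage keeps the symbol names exactly as
// the lowered IR expects them.
#include <cstdint>
#include <cstdlib>

extern "C" void *_mlir_alloc(uint64_t size) { return std::malloc(size); }

extern "C" void *_mlir_aligned_alloc(uint64_t alignment, uint64_t size) {
  // The argument order (alignment, then size) matches aligned_alloc and the
  // call emitted by the aligned-allocation lowering.
  return std::aligned_alloc(alignment, size);
}

extern "C" void _mlir_free(void *ptr) { std::free(ptr); }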

Added: 
    mlir/test/Conversion/MemRefToLLVM/generic-functions.mlir

Modified: 
    mlir/docs/TargetLLVMIR.md
    mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/TargetLLVMIR.md b/mlir/docs/TargetLLVMIR.md
index 2c1d222f406ef..c86ed7a9544b5 100644
--- a/mlir/docs/TargetLLVMIR.md
+++ b/mlir/docs/TargetLLVMIR.md
@@ -553,6 +553,17 @@ llvm.func @caller(%arg0: !llvm.ptr<f32>) {
 The "bare pointer" calling convention does not support unranked memrefs as their
 shape cannot be known at compile time.
 
+### Generic allocation and deallocation functions
+
+When converting the MemRef dialect, allocations and deallocations are converted
+into calls to `malloc` (`aligned_alloc` if aligned allocations are requested)
+and `free`. However, it is possible to convert them into calls to more generic
+functions that can be implemented by a runtime library, allowing custom
+allocation strategies or runtime profiling. When the conversion pass is
+instructed to do so (via the `use-generic-functions` option), the callees are
+named `_mlir_alloc`, `_mlir_aligned_alloc` and `_mlir_free`, and their
+signatures are the same as those of `malloc`, `aligned_alloc` and `free`.
+
 ### C-compatible wrapper emission
 
 In practical cases, it may be desirable to have externally-facing functions with

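To illustrate the customization mentioned in the documentation above, a hypothetical user-provided replacement for the generic functions could add simple allocation tracking. The counters and the reporting helper below are invented for the example and are not part of this change.

// Hypothetical profiling replacement for the generic allocation functions.
// Linking it into the final program instead of the default wrappers routes
// every lowered memref allocation through the instrumented entry points.
#include <atomic>
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

static std::atomic<uint64_t> liveAllocs{0};
static std::atomic<uint64_t> totalBytes{0};

extern "C" void *_mlir_alloc(uint64_t size) {
  liveAllocs.fetch_add(1, std::memory_order_relaxed);
  totalBytes.fetch_add(size, std::memory_order_relaxed);
  return std::malloc(size);
}

extern "C" void _mlir_free(void *ptr) {
  if (ptr)
    liveAllocs.fetch_sub(1, std::memory_order_relaxed);
  std::free(ptr);
}

// Invented helper: the host program can call it on exit to dump the counters.
// A real setup would instrument _mlir_aligned_alloc in the same way.
extern "C" void _mlir_report_alloc_stats() {
  std::printf("live allocations: %" PRIu64 ", bytes requested: %" PRIu64 "\n",
              liveAllocs.load(), totalBytes.load());
}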
diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h b/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
index 91f6f3d8addf5..240a144bc6cdb 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
@@ -48,6 +48,8 @@ class LowerToLLVMOptions {
 
   AllocLowering allocLowering = AllocLowering::Malloc;
 
+  bool useGenericFunctions = false;
+
   /// The data layout of the module to produce. This must be consistent with the
   /// data layout used in the upper levels of the lowering pipeline.
   // TODO: this should be replaced by MLIR data layout when one exists.

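For pipelines that configure the lowering programmatically rather than through the pass option, the new field can simply be set on the options struct before building the type converter. A sketch, assuming the usual LowerToLLVMOptions / LLVMTypeConverter constructors at this revision:

// Sketch: enabling the generic allocation functions when wiring up the
// MemRef-to-LLVM patterns by hand. The type converter must stay alive while
// the patterns that reference it are applied, hence the single scope.
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/IR/PatternMatch.h"

void buildMemRefLoweringPatterns(mlir::MLIRContext *ctx) {
  mlir::LowerToLLVMOptions options(ctx);
  // Emit calls to _mlir_alloc / _mlir_aligned_alloc / _mlir_free instead of
  // malloc / aligned_alloc / free.
  options.useGenericFunctions = true;

  mlir::LLVMTypeConverter typeConverter(ctx, options);
  mlir::RewritePatternSet patterns(ctx);
  mlir::populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
  // ... run applyPartialConversion with these patterns here ...
}

The equivalent command-line form is the `use-generic-functions` option of `-convert-memref-to-llvm`, exercised by the new test at the end of this patch.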
diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d2a076f21ac2d..8230e9dc18730 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -522,6 +522,11 @@ def ConvertMemRefToLLVM : Pass<"convert-memref-to-llvm", "ModuleOp"> {
     Option<"indexBitwidth", "index-bitwidth", "unsigned",
            /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
            "Bitwidth of the index type, 0 to use size of machine word">,
+    Option<"useGenericFunctions", "use-generic-functions",
+           "bool",
+           /*default=*/"false",
+           "Use generic allocation and deallocation functions instead of the "
+           "classic 'malloc', 'aligned_alloc' and 'free' functions">
   ];
 }
 

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 6380ff2d8e132..7a9167c5151f2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -45,6 +45,11 @@ LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType);
 LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
                                               Type indexType);
 LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
+                                              Type indexType);
+LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
+                                                     Type indexType);
+LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(ModuleOp moduleOp);
 LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
                                             Type unrankedDescriptorType);
 

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 7abc107e9f634..8ae0c3cdee71b 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -35,6 +35,15 @@ struct AllocOpLowering : public AllocLikeOpLLVMLowering {
       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                 converter) {}
 
+  LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const {
+    bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
+
+    if (useGenericFn)
+      return LLVM::lookupOrCreateGenericAllocFn(module, getIndexType());
+
+    return LLVM::lookupOrCreateMallocFn(module, getIndexType());
+  }
+
   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                           Location loc, Value sizeBytes,
                                           Operation *op) const override {
@@ -61,8 +70,7 @@ struct AllocOpLowering : public AllocLikeOpLLVMLowering {
     // Allocate the underlying buffer and store a pointer to it in the MemRef
     // descriptor.
     Type elementPtrType = this->getElementPtrType(memRefType);
-    auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
-        allocOp->getParentOfType<ModuleOp>(), getIndexType());
+    auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
     auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
                                   getVoidPtrType());
     Value allocatedPtr =
@@ -135,6 +143,15 @@ struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
                     llvm::PowerOf2Ceil(eltSizeBytes));
   }
 
+  LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const {
+    bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
+
+    if (useGenericFn)
+      return LLVM::lookupOrCreateGenericAlignedAllocFn(module, getIndexType());
+
+    return LLVM::lookupOrCreateAlignedAllocFn(module, getIndexType());
+  }
+
   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                           Location loc, Value sizeBytes,
                                           Operation *op) const override {
@@ -150,8 +167,7 @@ struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
       sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
 
     Type elementPtrType = this->getElementPtrType(memRefType);
-    auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
-        allocOp->getParentOfType<ModuleOp>(), getIndexType());
+    auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
     auto results =
         createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
                        getVoidPtrType());
@@ -300,11 +316,20 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
   explicit DeallocOpLowering(LLVMTypeConverter &converter)
       : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
 
+  LLVM::LLVMFuncOp getFreeFn(ModuleOp module) const {
+    bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
+
+    if (useGenericFn)
+      return LLVM::lookupOrCreateGenericFreeFn(module);
+
+    return LLVM::lookupOrCreateFreeFn(module);
+  }
+
   LogicalResult
   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Insert the `free` declaration if it is not already present.
-    auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
+    auto freeFunc = getFreeFn(op->getParentOfType<ModuleOp>());
     MemRefDescriptor memref(adaptor.getMemref());
     Value casted = rewriter.create<LLVM::BitcastOp>(
         op.getLoc(), getVoidPtrType(),
@@ -2047,6 +2072,9 @@ struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
     options.allocLowering =
         (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                          : LowerToLLVMOptions::AllocLowering::Malloc);
+
+    options.useGenericFunctions = useGenericFunctions;
+
     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
       options.overrideIndexBitwidth(indexBitwidth);
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index c3f8fcb422402..1a336fe29308b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -35,6 +35,9 @@ static constexpr llvm::StringRef kPrintNewline = "printNewline";
 static constexpr llvm::StringRef kMalloc = "malloc";
 static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc";
 static constexpr llvm::StringRef kFree = "free";
+static constexpr llvm::StringRef kGenericAlloc = "_mlir_alloc";
+static constexpr llvm::StringRef kGenericAlignedAlloc = "_mlir_aligned_alloc";
+static constexpr llvm::StringRef kGenericFree = "_mlir_free";
 static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
 
 /// Generic print function lookupOrCreate helper.
@@ -115,6 +118,28 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) {
       LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
+                                                          Type indexType) {
+  return LLVM::lookupOrCreateFn(
+      moduleOp, kGenericAlloc, indexType,
+      LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)));
+}
+
+LLVM::LLVMFuncOp
+mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
+                                                Type indexType) {
+  return LLVM::lookupOrCreateFn(
+      moduleOp, kGenericAlignedAlloc, {indexType, indexType},
+      LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(ModuleOp moduleOp) {
+  return LLVM::lookupOrCreateFn(
+      moduleOp, kGenericFree,
+      LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)),
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
 LLVM::LLVMFuncOp
 mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
                                        Type unrankedDescriptorType) {

diff  --git a/mlir/test/Conversion/MemRefToLLVM/generic-functions.mlir b/mlir/test/Conversion/MemRefToLLVM/generic-functions.mlir
new file mode 100644
index 0000000000000..624ae76c08061
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToLLVM/generic-functions.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -pass-pipeline="convert-memref-to-llvm{use-generic-functions=1}" -split-input-file %s \
+// RUN: | FileCheck %s --check-prefix="CHECK-NOTALIGNED"
+
+// RUN: mlir-opt -pass-pipeline="convert-memref-to-llvm{use-generic-functions=1 use-aligned-alloc=1}" -split-input-file %s \
+// RUN: | FileCheck %s --check-prefix="CHECK-ALIGNED"
+
+// CHECK-LABEL: func @zero_d_alloc()
+func.func @zero_d_alloc() -> memref<f32> {
+// CHECK-NOTALIGNED: llvm.call @_mlir_alloc(%{{.*}}) : (i64) -> !llvm.ptr<i8>
+// CHECK-ALIGNED: llvm.call @_mlir_aligned_alloc(%{{.*}}, %{{.*}}) : (i64, i64) -> !llvm.ptr<i8>
+  %0 = memref.alloc() : memref<f32>
+  return %0 : memref<f32>
+}
+
+// -----
+
+// CHECK-LABEL: func @dealloc()
+func.func @dealloc(%arg0: memref<f32>) {
+// CHECK-NOTALIGNED: llvm.call @_mlir_free(%{{.*}}) : (!llvm.ptr<i8>) -> ()
+// CHECK-ALIGNED: llvm.call @_mlir_free(%{{.*}}) : (!llvm.ptr<i8>) -> ()
+  memref.dealloc %arg0 : memref<f32>
+  return
+}


        

