[Mlir-commits] [mlir] 82ea0d8 - [mlir][bufferize] Support alloc hoisting across function boundaries

Matthias Springer llvmlistbot at llvm.org
Thu May 12 00:47:25 PDT 2022


Author: Matthias Springer
Date: 2022-05-12T09:44:07+02:00
New Revision: 82ea0d8b824892fd04f82416782369cd77da836a

URL: https://github.com/llvm/llvm-project/commit/82ea0d8b824892fd04f82416782369cd77da836a
DIFF: https://github.com/llvm/llvm-project/commit/82ea0d8b824892fd04f82416782369cd77da836a.diff

LOG: [mlir][bufferize] Support alloc hoisting across function boundaries

This change integrates the BufferResultsToOutParamsPass into One-Shot Module Bufferization. This improves memory management (deallocation) when buffers are returned from a function.

Note: This currently only works with statically-sized tensors. The generated code is not very efficient yet and there are opportunities for improvement (fewer copies). By default, this new functionality is deactivated.

Differential Revision: https://reviews.llvm.org/D125376

Added: 
    mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
    mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
    mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 70c0f00fac433..4d008add40cd9 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -221,6 +221,10 @@ struct BufferizationOptions {
   /// For debugging only. Should be used together with `testAnalysisOnly`.
   bool printConflicts = false;
 
+  /// If set to `true`, buffers that are returned from functions are replaced
+  /// with buffer "out" parameters. At the call site, new buffers are allocated.
+  bool promoteBufferResultsToOutParams = false;
+
   /// If set to `true`, an `getAliasingOpResult` will return the corresponding
   /// "out"/"dest" OpOperand for every op that has the notion of an "out"/"dest"
   /// operand. I.e., the aliasing OpOperand of the i-th tensor OpResult is

diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 4a72934d4f506..dd39cd528bfaf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -4,6 +4,8 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+class ModuleOp;
+
 namespace func {
 class FuncOp;
 } // namespace func
@@ -33,6 +35,10 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
 /// Creates a pass that converts memref function results to out-params.
 std::unique_ptr<Pass> createBufferResultsToOutParamsPass();
 
+/// Replace buffers that are returned from a function with an out parameter.
+/// Also update all call sites.
+LogicalResult promoteBufferResultsToOutParams(ModuleOp module);
+
 /// Creates a pass that finalizes a partial bufferization by removing remaining
 /// bufferization.to_tensor and bufferization.to_memref operations.
 std::unique_ptr<OperationPass<func::FuncOp>> createFinalizingBufferizePass();

diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 1820df8cb9b80..a19f92cca6902 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -264,6 +264,9 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
             /*default=*/"false",
            "Test only: Annotate IR with RaW conflicts. Requires "
            "test-analysis-only.">,
+    Option<"promoteBufferResultsToOutParams",
+           "promote-buffer-results-to-out-params", "bool", /*default=*/"false",
+           "Replace returned buffers (that were not dropped) with out params.">,
   ];
   let constructor = "mlir::bufferization::createOneShotBufferizePass()";
 }

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index ba803ea65bdda..c9018280ea23f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -122,20 +122,25 @@ static LogicalResult updateCalls(ModuleOp module) {
   return failure(didFail);
 }
 
+LogicalResult
+mlir::bufferization::promoteBufferResultsToOutParams(ModuleOp module) {
+  for (auto func : module.getOps<func::FuncOp>()) {
+    SmallVector<BlockArgument, 6> appendedEntryArgs;
+    updateFuncOp(func, appendedEntryArgs);
+    if (func.isExternal())
+      continue;
+    updateReturnOps(func, appendedEntryArgs);
+  }
+  if (failed(updateCalls(module)))
+    return failure();
+  return success();
+}
+
 namespace {
 struct BufferResultsToOutParamsPass
     : BufferResultsToOutParamsBase<BufferResultsToOutParamsPass> {
   void runOnOperation() override {
-    ModuleOp module = getOperation();
-
-    for (auto func : module.getOps<func::FuncOp>()) {
-      SmallVector<BlockArgument, 6> appendedEntryArgs;
-      updateFuncOp(func, appendedEntryArgs);
-      if (func.isExternal())
-        continue;
-      updateReturnOps(func, appendedEntryArgs);
-    }
-    if (failed(updateCalls(module)))
+    if (failed(bufferization::promoteBufferResultsToOutParams(getOperation())))
       return signalPassFailure();
   }
 };

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index ad8ced3fc5acb..eb0aeaba0e65b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -179,6 +179,7 @@ struct OneShotBufferizePass
       opt.printConflicts = printConflicts;
       opt.testAnalysisOnly = testAnalysisOnly;
       opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
+      opt.promoteBufferResultsToOutParams = promoteBufferResultsToOutParams;
 
       BufferizationOptions::OpFilterEntry::FilterFn filterFn =
           [&](Operation *op) {
@@ -263,17 +264,29 @@ bufferization::finalizeBuffers(Operation *op,
   if (failed(hoistBufferAllocations(op, options)))
     return failure();
 
-  // Deallocate buffers that escape block boundaries ("leaking buffers") with
-  // the buffer deallocation pass.
-  bool hasLeakingAlloc = false;
+  // Create allocation ops for "leaking buffers", i.e., buffer allocations that
+  // escape block boundaries. If there are no leaking allocs, `hasLeakingAllocs`
+  // is set to `false`.
+  bool hasLeakingAllocs = false;
   if (failed(createAllocDeallocOps(op, options, /*onlyLeakingAllocs=*/true,
-                                   &hasLeakingAlloc)))
-    return failure();
-  if (options.createDeallocs && hasLeakingAlloc &&
-      failed(deallocateBuffers(op)))
+                                   &hasLeakingAllocs)))
     return failure();
 
-  // Deallocate all remaining buffers at the end of the block.
+  if (hasLeakingAllocs) {
+    // Promote returned buffers to "out" parameters.
+    // TODO: Pass options to support custom dealloc ops.
+    if (options.promoteBufferResultsToOutParams && isa<ModuleOp>(op) &&
+        failed(promoteBufferResultsToOutParams(cast<ModuleOp>(op))))
+      return failure();
+
+    // Create deallocation ops for all "leaking buffers" and all buffer
+    // allocations that were added during the above promotion process.
+    // TODO: Pass options to support custom dealloc ops.
+    if (options.createDeallocs && failed(deallocateBuffers(op)))
+      return failure();
+  }
+
+  // Deallocate all remaining buffers at the end of their parent blocks.
   if (failed(createAllocDeallocOps(op, options)))
     return failure();
 

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir
new file mode 100644
index 0000000000000..517f71b0aef41
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries allow-return-allocs promote-buffer-results-to-out-params" -split-input-file | FileCheck %s
+
+// Note: This bufferization is not very efficient yet, but it works.
+
+// CHECK: #[[$map1:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+// CHECK-LABEL: func @callee(
+//  CHECK-SAME:              %[[arg0:.*]]: memref<5xf32, #[[$map1]]>,
+//  CHECK-SAME:              %[[arg1:.*]]: memref<5xf32>) {
+//       CHECK:   %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
+//       CHECK:   memref.copy %[[arg0]], %[[alloc]]
+//       CHECK:   memref.store %{{.*}}, %[[alloc]]
+//       CHECK:   memref.copy %[[alloc]], %[[arg1]]
+//       CHECK:   memref.dealloc %[[alloc]]
+//       CHECK:   return
+//       CHECK: }
+func.func @callee(%t: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 8.0 : f32
+  %1 = tensor.insert %cst into %t[%c0] : tensor<5xf32>
+  return %t, %1 : tensor<5xf32>, tensor<5xf32>
+}
+
+// CHECK: func @main(%[[arg0:.*]]: memref<5xf32, #[[$map1]]>) -> (f32, f32) {
+// CHECK:   %[[alloc:.*]] = memref.alloc() : memref<5xf32>
+// CHECK:   call @callee(%[[arg0]], %[[alloc]])
+// CHECK:   %[[l1:.*]] = memref.load %[[arg0]]
+// CHECK:   %[[l2:.*]] = memref.load %[[alloc]]
+// CHECK:   memref.dealloc %[[alloc]]
+// CHECK:   return %[[l1]], %[[l2]]
+// CHECK: }
+func.func @main(%t: tensor<5xf32>) -> (f32, f32) {
+  %c0 = arith.constant 0 : index
+  %0, %1 = func.call @callee(%t)
+      : (tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>)
+  %2 = tensor.extract %0[%c0] : tensor<5xf32>
+  %3 = tensor.extract %1[%c0] : tensor<5xf32>
+  return %2, %3 : f32, f32
+}
+


        


More information about the Mlir-commits mailing list