[Mlir-commits] [mlir] 82ea0d8 - [mlir][bufferize] Support alloc hoisting across function boundaries
Matthias Springer
llvmlistbot at llvm.org
Thu May 12 00:47:25 PDT 2022
Author: Matthias Springer
Date: 2022-05-12T09:44:07+02:00
New Revision: 82ea0d8b824892fd04f82416782369cd77da836a
URL: https://github.com/llvm/llvm-project/commit/82ea0d8b824892fd04f82416782369cd77da836a
DIFF: https://github.com/llvm/llvm-project/commit/82ea0d8b824892fd04f82416782369cd77da836a.diff
LOG: [mlir][bufferize] Support alloc hoisting across function boundaries
This change integrates the BufferResultsToOutParamsPass into One-Shot Module Bufferization. This improves memory management (deallocation) when buffers are returned from a function.
Note: This currently only works with statically-sized tensors. The generated code is not very efficient yet and there are opportunities for improvement (fewer copies). By default, this new functionality is deactivated.
Differential Revision: https://reviews.llvm.org/D125376
Added:
mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 70c0f00fac433..4d008add40cd9 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -221,6 +221,10 @@ struct BufferizationOptions {
/// For debugging only. Should be used together with `testAnalysisOnly`.
bool printConflicts = false;
+ /// If set to `true`, buffers that are returned from functions are replaced
+ /// with buffer "out" parameters. At the call site, new buffers are allocated.
+ bool promoteBufferResultsToOutParams = false;
+
/// If set to `true`, an `getAliasingOpResult` will return the corresponding
/// "out"/"dest" OpOperand for every op that has the notion of an "out"/"dest"
/// operand. I.e., the aliasing OpOperand of the i-th tensor OpResult is
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 4a72934d4f506..dd39cd528bfaf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -4,6 +4,8 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
+class ModuleOp;
+
namespace func {
class FuncOp;
} // namespace func
@@ -33,6 +35,10 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
/// Creates a pass that converts memref function results to out-params.
std::unique_ptr<Pass> createBufferResultsToOutParamsPass();
+/// Replace buffers that are returned from a function with an out parameter.
+/// Also update all call sites.
+LogicalResult promoteBufferResultsToOutParams(ModuleOp module);
+
/// Creates a pass that finalizes a partial bufferization by removing remaining
/// bufferization.to_tensor and bufferization.to_memref operations.
std::unique_ptr<OperationPass<func::FuncOp>> createFinalizingBufferizePass();
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 1820df8cb9b80..a19f92cca6902 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -264,6 +264,9 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
/*default=*/"false",
"Test only: Annotate IR with RaW conflicts. Requires "
"test-analysis-only.">,
+ Option<"promoteBufferResultsToOutParams",
+ "promote-buffer-results-to-out-params", "bool", /*default=*/"false",
+ "Replace returned buffers (that were not dropped) with out params.">,
];
let constructor = "mlir::bufferization::createOneShotBufferizePass()";
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index ba803ea65bdda..c9018280ea23f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -122,20 +122,25 @@ static LogicalResult updateCalls(ModuleOp module) {
return failure(didFail);
}
+LogicalResult
+mlir::bufferization::promoteBufferResultsToOutParams(ModuleOp module) {
+ for (auto func : module.getOps<func::FuncOp>()) {
+ SmallVector<BlockArgument, 6> appendedEntryArgs;
+ updateFuncOp(func, appendedEntryArgs);
+ if (func.isExternal())
+ continue;
+ updateReturnOps(func, appendedEntryArgs);
+ }
+ if (failed(updateCalls(module)))
+ return failure();
+ return success();
+}
+
namespace {
struct BufferResultsToOutParamsPass
: BufferResultsToOutParamsBase<BufferResultsToOutParamsPass> {
void runOnOperation() override {
- ModuleOp module = getOperation();
-
- for (auto func : module.getOps<func::FuncOp>()) {
- SmallVector<BlockArgument, 6> appendedEntryArgs;
- updateFuncOp(func, appendedEntryArgs);
- if (func.isExternal())
- continue;
- updateReturnOps(func, appendedEntryArgs);
- }
- if (failed(updateCalls(module)))
+ if (failed(bufferization::promoteBufferResultsToOutParams(getOperation())))
return signalPassFailure();
}
};
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index ad8ced3fc5acb..eb0aeaba0e65b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -179,6 +179,7 @@ struct OneShotBufferizePass
opt.printConflicts = printConflicts;
opt.testAnalysisOnly = testAnalysisOnly;
opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
+ opt.promoteBufferResultsToOutParams = promoteBufferResultsToOutParams;
BufferizationOptions::OpFilterEntry::FilterFn filterFn =
[&](Operation *op) {
@@ -263,17 +264,29 @@ bufferization::finalizeBuffers(Operation *op,
if (failed(hoistBufferAllocations(op, options)))
return failure();
- // Deallocate buffers that escape block boundaries ("leaking buffers") with
- // the buffer deallocation pass.
- bool hasLeakingAlloc = false;
+ // Create allocation ops for "leaking buffers", i.e., buffer allocations that
+ // escape block boundaries. If there are no leaking allocs, `hasLeakingAllocs`
+ // is set to `false`.
+ bool hasLeakingAllocs = false;
if (failed(createAllocDeallocOps(op, options, /*onlyLeakingAllocs=*/true,
- &hasLeakingAlloc)))
- return failure();
- if (options.createDeallocs && hasLeakingAlloc &&
- failed(deallocateBuffers(op)))
+ &hasLeakingAllocs)))
return failure();
- // Deallocate all remaining buffers at the end of the block.
+ if (hasLeakingAllocs) {
+ // Promote returned buffers to "out" parameters.
+ // TODO: Pass options to support custom dealloc ops.
+ if (options.promoteBufferResultsToOutParams && isa<ModuleOp>(op) &&
+ failed(promoteBufferResultsToOutParams(cast<ModuleOp>(op))))
+ return failure();
+
+ // Create deallocation ops for all "leaking buffers" and all buffer
+ // allocations that were added during the above promotion process.
+ // TODO: Pass options to support custom dealloc ops.
+ if (options.createDeallocs && failed(deallocateBuffers(op)))
+ return failure();
+ }
+
+ // Deallocate all remaining buffers at the end of their parent blocks.
if (failed(createAllocDeallocOps(op, options)))
return failure();
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir
new file mode 100644
index 0000000000000..517f71b0aef41
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries allow-return-allocs promote-buffer-results-to-out-params" -split-input-file | FileCheck %s
+
+// Note: This bufferization is not very efficient yet, but it works.
+
+// CHECK: #[[$map1:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+// CHECK-LABEL: func @callee(
+// CHECK-SAME: %[[arg0:.*]]: memref<5xf32, #[[$map1]]>,
+// CHECK-SAME: %[[arg1:.*]]: memref<5xf32>) {
+// CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
+// CHECK: memref.copy %[[arg0]], %[[alloc]]
+// CHECK: memref.store %{{.*}}, %[[alloc]]
+// CHECK: memref.copy %[[alloc]], %[[arg1]]
+// CHECK: memref.dealloc %[[alloc]]
+// CHECK: return
+// CHECK: }
+func.func @callee(%t: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 8.0 : f32
+ %1 = tensor.insert %cst into %t[%c0] : tensor<5xf32>
+ return %t, %1 : tensor<5xf32>, tensor<5xf32>
+}
+
+// CHECK: func @main(%[[arg0:.*]]: memref<5xf32, #[[$map1]]>) -> (f32, f32) {
+// CHECK: %[[alloc:.*]] = memref.alloc() : memref<5xf32>
+// CHECK: call @callee(%[[arg0]], %[[alloc]])
+// CHECK: %[[l1:.*]] = memref.load %[[arg0]]
+// CHECK: %[[l2:.*]] = memref.load %[[alloc]]
+// CHECK: memref.dealloc %[[alloc]]
+// CHECK: return %[[l1]], %[[l2]]
+// CHECK: }
+func.func @main(%t: tensor<5xf32>) -> (f32, f32) {
+ %c0 = arith.constant 0 : index
+ %0, %1 = func.call @callee(%t)
+ : (tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>)
+ %2 = tensor.extract %0[%c0] : tensor<5xf32>
+ %3 = tensor.extract %1[%c0] : tensor<5xf32>
+ return %2, %3 : f32, f32
+}
+
More information about the Mlir-commits
mailing list