[Mlir-commits] [mlir] be8742b - [mlir][linalg][bufferize][NFC] Merge AllocationCallbacks into BufferizationOptions
Matthias Springer
llvmlistbot at llvm.org
Wed Jan 19 01:43:19 PST 2022
Author: Matthias Springer
Date: 2022-01-19T18:36:34+09:00
New Revision: be8742b6c9c7d919db33212a27dc230632dbb1d3
URL: https://github.com/llvm/llvm-project/commit/be8742b6c9c7d919db33212a27dc230632dbb1d3
DIFF: https://github.com/llvm/llvm-project/commit/be8742b6c9c7d919db33212a27dc230632dbb1d3.diff
LOG: [mlir][linalg][bufferize][NFC] Merge AllocationCallbacks into BufferizationOptions
Also move `createAlloc` and related helper functions out of BufferizationState and turn them into free functions that take a `BufferizationOptions` argument. The goal is to keep BufferizationState as small as possible (code cleanup).
Differential Revision: https://reviews.llvm.org/D117476
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 807d86331401..8fb05125deeb 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -38,34 +38,6 @@ struct BufferizationOptions;
class BufferizationState;
struct PostAnalysisStep;
-/// Callback functions that are used to allocate/deallocate/copy memory buffers.
-/// Comprehensive Bufferize provides default implementations of these functions.
-// TODO: Could be replaced with a "bufferization strategy" object with virtual
-// functions in the future.
-struct AllocationCallbacks {
- using AllocationFn = std::function<FailureOr<Value>(
- OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
- using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
- using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
-
- AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
- MemCpyFn copyFn)
- : allocationFn(std::move(allocFn)), deallocationFn(std::move(deallocFn)),
- memCpyFn(std::move(copyFn)) {}
-
- /// A function that allocates memory.
- AllocationFn allocationFn;
-
- /// A function that deallocated memory. Must be allocated by `allocationFn`.
- DeallocationFn deallocationFn;
-
- /// A function that copies memory between two allocations.
- MemCpyFn memCpyFn;
-};
-
-/// Return default allocation callbacks.
-std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();
-
/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
/// executed after the analysis, but before bufferization. They can be used to
/// implement custom dialect-specific optimizations.
@@ -84,6 +56,13 @@ using PostAnalysisStepList = std::vector<std::unique_ptr<PostAnalysisStep>>;
/// Options for ComprehensiveBufferize.
struct BufferizationOptions {
+ using AllocationFn = std::function<FailureOr<Value>(
+ OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
+ using DeallocationFn =
+ std::function<LogicalResult(OpBuilder &, Location, Value)>;
+ using MemCpyFn =
+ std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
+
BufferizationOptions();
// BufferizationOptions cannot be copied.
@@ -126,7 +105,9 @@ struct BufferizationOptions {
BufferizableOpInterface dynCastBufferizableOp(Value value) const;
/// Helper functions for allocation, deallocation, memory copying.
- std::unique_ptr<AllocationCallbacks> allocationFns;
+ Optional<AllocationFn> allocationFn;
+ Optional<DeallocationFn> deallocationFn;
+ Optional<MemCpyFn> memCpyFn;
/// Specifies whether returning newly allocated memrefs should be allowed.
/// Otherwise, a pass failure is triggered.
@@ -362,24 +343,6 @@ class BufferizationState {
/// is returned regardless of whether it is a memory write or not.
SetVector<Value> findLastPrecedingWrite(Value value) const;
- /// Creates a memref allocation.
- FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
- ArrayRef<Value> dynShape) const;
-
- /// Creates a memref allocation for the given shaped value. This function may
- /// perform additional optimizations such as buffer allocation hoisting. If
- /// `createDealloc`, a deallocation op is inserted at the point where the
- /// allocation goes out of scope.
- FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
- bool deallocMemref) const;
-
- /// Creates a memref deallocation. The given memref buffer must have been
- /// allocated using `createAlloc`.
- void createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer) const;
-
- /// Creates a memcpy between two given buffers.
- void createMemCpy(OpBuilder &b, Location loc, Value from, Value to) const;
-
/// Return `true` if the given OpResult has been decided to bufferize inplace.
bool isInPlace(OpOperand &opOperand) const;
@@ -458,6 +421,28 @@ UnrankedMemRefType getUnrankedMemRefType(Type elementType,
MemRefType getDynamicMemRefType(RankedTensorType tensorType,
unsigned addressSpace = 0);
+/// Creates a memref allocation.
+FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
+ ArrayRef<Value> dynShape,
+ const BufferizationOptions &options);
+
+/// Creates a memref allocation for the given shaped value. This function may
+/// perform additional optimizations such as buffer allocation hoisting. If
+/// `deallocMemref` is set, a deallocation op is inserted at the point where
+/// the allocation goes out of scope.
+FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
+ bool deallocMemref,
+ const BufferizationOptions &options);
+
+/// Creates a memref deallocation. The given memref buffer must have been
+/// allocated using `createAlloc`.
+LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
+ const BufferizationOptions &options);
+
+/// Creates a memcpy between two given buffers.
+LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, Value to,
+ const BufferizationOptions &options);
+
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 048a4a39111f..885a70f56c64 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -39,40 +39,8 @@ using namespace linalg::comprehensive_bufferize;
// BufferizationOptions
//===----------------------------------------------------------------------===//
-/// Default allocation function that is used by the comprehensive bufferization
-/// pass. The default currently creates a ranked memref using `memref.alloc`.
-static FailureOr<Value> defaultAllocationFn(OpBuilder &b, Location loc,
- MemRefType type,
- ArrayRef<Value> dynShape) {
- Value allocated = b.create<memref::AllocOp>(
- loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
- return allocated;
-}
-
-/// Default deallocation function that is used by the comprehensive
-/// bufferization pass. It expects to recieve back the value called from the
-/// `defaultAllocationFn`.
-static void defaultDeallocationFn(OpBuilder &b, Location loc,
- Value allocatedBuffer) {
- b.create<memref::DeallocOp>(loc, allocatedBuffer);
-}
-
-/// Default memory copy function that is used by the comprehensive bufferization
-/// pass. Creates a `memref.copy` op.
-static void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to) {
- b.create<memref::CopyOp>(loc, from, to);
-}
-
-std::unique_ptr<AllocationCallbacks>
-mlir::linalg::comprehensive_bufferize::defaultAllocationCallbacks() {
- return std::make_unique<AllocationCallbacks>(
- defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn);
-}
-
-// Default constructor for BufferizationOptions that sets all allocation
-// callbacks to their default functions.
-BufferizationOptions::BufferizationOptions()
- : allocationFns(defaultAllocationCallbacks()) {}
+// Default constructor for BufferizationOptions.
+BufferizationOptions::BufferizationOptions() {}
BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
@@ -393,8 +361,8 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
// allocation should be inserted (in the absence of allocation hoisting).
setInsertionPointAfter(rewriter, operandBuffer);
// Allocate the result buffer.
- FailureOr<Value> resultBuffer =
- createAlloc(rewriter, loc, operandBuffer, options.createDeallocs);
+ FailureOr<Value> resultBuffer = createAlloc(rewriter, loc, operandBuffer,
+ options.createDeallocs, options);
if (failed(resultBuffer))
return failure();
// Do not copy if the last preceding writes of `operand` are ops that do
@@ -425,7 +393,9 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
// The copy happens right before the op that is bufferized.
rewriter.setInsertionPoint(op);
}
- createMemCpy(rewriter, loc, operandBuffer, *resultBuffer);
+ if (failed(
+ createMemCpy(rewriter, loc, operandBuffer, *resultBuffer, options)))
+ return failure();
return resultBuffer;
}
@@ -545,9 +515,9 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
/// Create an AllocOp/DeallocOp pair, where the AllocOp is after
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
/// bbArg) and the DeallocOp is at the end of the block.
-FailureOr<Value>
-mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
- OpBuilder &b, Location loc, Value shapedValue, bool deallocMemref) const {
+FailureOr<Value> mlir::linalg::comprehensive_bufferize::createAlloc(
+ OpBuilder &b, Location loc, Value shapedValue, bool deallocMemref,
+ const BufferizationOptions &options) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -558,7 +528,8 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
// Note: getAllocationTypeAndShape also sets the insertion point.
MemRefType allocMemRefType =
getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
- FailureOr<Value> allocated = createAlloc(b, loc, allocMemRefType, dynShape);
+ FailureOr<Value> allocated =
+ createAlloc(b, loc, allocMemRefType, dynShape, options);
if (failed(allocated))
return failure();
Value casted = allocated.getValue();
@@ -572,30 +543,47 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
if (deallocMemref) {
// 2. Create memory deallocation.
b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
- createDealloc(b, loc, allocated.getValue());
+ if (failed(createDealloc(b, loc, allocated.getValue(), options)))
+ return failure();
}
return casted;
}
/// Create a memref allocation.
-FailureOr<Value>
-mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
- OpBuilder &b, Location loc, MemRefType type,
- ArrayRef<Value> dynShape) const {
- return options.allocationFns->allocationFn(b, loc, type, dynShape);
+FailureOr<Value> mlir::linalg::comprehensive_bufferize::createAlloc(
+ OpBuilder &b, Location loc, MemRefType type, ArrayRef<Value> dynShape,
+ const BufferizationOptions &options) {
+ if (options.allocationFn)
+ return (*options.allocationFn)(b, loc, type, dynShape);
+
+ // Default buffer allocation via AllocOp.
+ Value allocated = b.create<memref::AllocOp>(
+ loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
+ return allocated;
}
/// Create a memref deallocation.
-void mlir::linalg::comprehensive_bufferize::BufferizationState::createDealloc(
- OpBuilder &b, Location loc, Value allocatedBuffer) const {
- return options.allocationFns->deallocationFn(b, loc, allocatedBuffer);
+LogicalResult mlir::linalg::comprehensive_bufferize::createDealloc(
+ OpBuilder &b, Location loc, Value allocatedBuffer,
+ const BufferizationOptions &options) {
+ if (options.deallocationFn)
+ return (*options.deallocationFn)(b, loc, allocatedBuffer);
+
+ // Default buffer deallocation via DeallocOp.
+ b.create<memref::DeallocOp>(loc, allocatedBuffer);
+ return success();
}
/// Create a memory copy between two memref buffers.
-void mlir::linalg::comprehensive_bufferize::BufferizationState::createMemCpy(
- OpBuilder &b, Location loc, Value from, Value to) const {
- return options.allocationFns->memCpyFn(b, loc, from, to);
+LogicalResult mlir::linalg::comprehensive_bufferize::createMemCpy(
+ OpBuilder &b, Location loc, Value from, Value to,
+ const BufferizationOptions &options) {
+ if (options.memCpyFn)
+ return (*options.memCpyFn)(b, loc, from, to);
+
+ b.create<memref::CopyOp>(loc, from, to);
+ return success();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 64bc6920da07..6d153522811d 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -221,9 +221,9 @@ struct InitTensorOpInterface
if (initTensorOp->getUses().empty())
return success();
- FailureOr<Value> alloc = state.createAlloc(
- rewriter, initTensorOp->getLoc(), initTensorOp.result(),
- state.getOptions().createDeallocs);
+ FailureOr<Value> alloc =
+ createAlloc(rewriter, initTensorOp->getLoc(), initTensorOp.result(),
+ state.getOptions().createDeallocs, state.getOptions());
if (failed(alloc))
return failure();
replaceOpWithBufferizedValues(rewriter, op, *alloc);
@@ -367,7 +367,9 @@ struct TiledLoopOpInterface
Value output = std::get<1>(it);
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
newTerminator.getLoc(), output.getType(), std::get<0>(it));
- state.createMemCpy(rewriter, newTerminator.getLoc(), toMemrefOp, output);
+ if (failed(createMemCpy(rewriter, newTerminator.getLoc(), toMemrefOp,
+ output, state.getOptions())))
+ return failure();
}
// Erase old terminator.
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 620328799712..a31832452ea8 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -158,8 +158,8 @@ struct ExtractSliceOpInterface
Value alloc;
if (!inplace) {
FailureOr<Value> allocOrFailure =
- state.createAlloc(rewriter, loc, extractSliceOp.result(),
- state.getOptions().createDeallocs);
+ createAlloc(rewriter, loc, extractSliceOp.result(),
+ state.getOptions().createDeallocs, state.getOptions());
if (failed(allocOrFailure))
return failure();
alloc = *allocOrFailure;
@@ -191,7 +191,9 @@ struct ExtractSliceOpInterface
if (!inplace) {
// Do not copy if the copied data is never read.
if (state.isValueRead(extractSliceOp.result()))
- state.createMemCpy(rewriter, extractSliceOp.getLoc(), subView, alloc);
+ if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
+ alloc, state.getOptions())))
+ return failure();
subView = alloc;
}
@@ -461,7 +463,9 @@ struct InsertSliceOpInterface
// tensor.extract_slice, the copy operation will eventually fold away.
Value srcMemref =
*state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
- state.createMemCpy(rewriter, loc, srcMemref, subView);
+ if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
+ state.getOptions())))
+ return failure();
replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
return success();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index f4e9476727f0..36da9e52ea3c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -77,9 +77,10 @@ static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
void LinalgComprehensiveModuleBufferize::runOnOperation() {
auto options = std::make_unique<BufferizationOptions>();
if (useAlloca) {
- options->allocationFns->allocationFn = allocationFnUsingAlloca;
- options->allocationFns->deallocationFn = [](OpBuilder &b, Location loc,
- Value v) {};
+ options->allocationFn = allocationFnUsingAlloca;
+ options->deallocationFn = [](OpBuilder &b, Location loc, Value v) {
+ return success();
+ };
}
options->allowReturnMemref = allowReturnMemref;
More information about the Mlir-commits mailing list