[Mlir-commits] [mlir] 05e0495 - [mlir][bufferize][NFC] Deallocate all buffers at the end of bufferization
Matthias Springer
llvmlistbot at llvm.org
Tue Mar 15 01:58:09 PDT 2022
Author: Matthias Springer
Date: 2022-03-15T17:53:53+09:00
New Revision: 05e0495f1d0c6386c8ee30df15d53a7cf4d6bfdc
URL: https://github.com/llvm/llvm-project/commit/05e0495f1d0c6386c8ee30df15d53a7cf4d6bfdc
DIFF: https://github.com/llvm/llvm-project/commit/05e0495f1d0c6386c8ee30df15d53a7cf4d6bfdc.diff
LOG: [mlir][bufferize][NFC] Deallocate all buffers at the end of bufferization
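This makes bufferization more modular. This is in preparation for future refactorings. Bufferization now emits placeholder `memref.alloca` ops tagged with the `bufferization.allocation` attribute; a new `finalizeBuffers` step then rewrites them into alloc ops as specified in the bufferization options and, if `createDeallocs` is set, inserts matching dealloc ops before the terminator of the enclosing block.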
Differential Revision: https://reviews.llvm.org/D121362
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index c7cfc7241a509..500f7e89b1758 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -415,13 +415,22 @@ struct BufferizationState {
BufferizationState(const AnalysisState &analysisState)
: analysisState(analysisState) {}
+ /// Creates a memref allocation with the given type and dynamic extents.
+ FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
+ ValueRange dynShape);
+
+ /// Creates a memref allocation for the given shaped value. This function may
+ /// perform additional optimizations such as buffer allocation hoisting.
+ // TODO: Allocation hoisting should be a cleanup pass.
+ FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue);
+
/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization was decided.
FailureOr<Value>
getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
bool forceInPlace = false,
- Optional<Operation *> customCopyInsertionPoint = None) const;
+ Optional<Operation *> customCopyInsertionPoint = None);
/// Return a reference to the BufferizationOptions.
const BufferizationOptions &getOptions() const {
@@ -477,27 +486,6 @@ BaseMemRefType getMemRefType(TensorType tensorType,
MemRefLayoutAttrInterface layout = {},
Attribute memorySpace = {});
-/// Creates a memref allocation with the given type and dynamic extents.
-FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
- ValueRange dynShape,
- const BufferizationOptions &options);
-
-/// Creates a memref allocation with the given type and dynamic extents. If
-/// `createDealloc`, a deallocation op is inserted at the point where the
-/// allocation goes out of scope.
-FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
- ValueRange dynShape, bool deallocMemref,
- const BufferizationOptions &options);
-
-/// Creates a memref allocation for the given shaped value. This function may
-/// perform additional optimizations such as buffer allocation hoisting. If
-/// `createDealloc`, a deallocation op is inserted at the point where the
-/// allocation goes out of scope.
-// TODO: Allocation hoisting should be a cleanup pass.
-FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
- bool deallocMemref,
- const BufferizationOptions &options);
-
/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
@@ -507,6 +495,10 @@ LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, Value to,
const BufferizationOptions &options);
+/// Finalize all buffer allocations, i.e., create alloc ops as specified in the
+/// bufferization options and deallocate all buffers.
+LogicalResult finalizeBuffers(Operation *op,
+ const BufferizationOptions &options);
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 13b48d09d33ff..85a70a42d283a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -70,13 +70,6 @@ void populateEliminateBufferizeMaterializationsPatterns(
// TODO: Extract `options` from `state` and pass as separate argument.
LogicalResult bufferizeOp(Operation *op, const AnalysisState &analysisState);
-/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
-/// Reuse an existing `BufferizationState`.
-///
-/// Note: This function overload is useful for extending the bufferization.
-LogicalResult bufferizeOp(Operation *op,
- BufferizationState &bufferizationState);
-
/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
/// Buffers are duplicated and copied before any tensor use that bufferizes to
/// a memory write.
@@ -87,6 +80,16 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options);
BufferizationOptions getPartialBufferizationOptions();
+//===----------------------------------------------------------------------===//
+// Helper functions for extending Bufferization
+//===----------------------------------------------------------------------===//
+
+/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
+/// Reuse an existing `BufferizationState`.
+///
+/// Note: This function overload is useful for extending the bufferization.
+LogicalResult bufferizeOp(Operation *op,
+ BufferizationState &bufferizationState);
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index d2b8f1de5628f..8a3dbfc960e9b 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -42,6 +42,8 @@ constexpr const ::llvm::StringLiteral
constexpr const ::llvm::StringLiteral
bufferization::BufferizableOpInterface::kInplaceableAttrName;
+static const char *kBufferAllocationAttr = "bufferization.allocation";
+
//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//
@@ -243,9 +245,10 @@ Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
-FailureOr<Value> BufferizationState::getBuffer(
- RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace,
- Optional<Operation *> customCopyInsertionPoint) const {
+FailureOr<Value>
+BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
+ bool forceInPlace,
+ Optional<Operation *> customCopyInsertionPoint) {
const BufferizationOptions &options = analysisState.getOptions();
OpBuilder::InsertionGuard guard(rewriter);
Operation *op = opOperand.getOwner();
@@ -261,8 +264,7 @@ FailureOr<Value> BufferizationState::getBuffer(
// allocation should be inserted (in the absence of allocation hoisting).
setInsertionPointAfter(rewriter, operandBuffer);
// Allocate the result buffer.
- FailureOr<Value> resultBuffer = createAlloc(rewriter, loc, operandBuffer,
- options.createDeallocs, options);
+ FailureOr<Value> resultBuffer = createAlloc(rewriter, loc, operandBuffer);
if (failed(resultBuffer))
return failure();
// Do not copy if the last preceding writes of `operand` are ops that do
@@ -358,6 +360,33 @@ bool AlwaysCopyAnalysisState::areEquivalentBufferizedValues(Value v1,
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//
+/// Create a memref allocation with the given type and dynamic extents.
+static FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
+ ValueRange dynShape,
+ const BufferizationOptions &options) {
+ if (options.allocationFn)
+ return (*options.allocationFn)(b, loc, type, dynShape,
+ options.bufferAlignment);
+
+ // Default bufferallocation via AllocOp.
+ Value allocated = b.create<memref::AllocOp>(
+ loc, type, dynShape, b.getI64IntegerAttr(options.bufferAlignment));
+ return allocated;
+}
+
+/// Creates a memref deallocation. The given memref buffer must have been
+/// allocated using `createAlloc`.
+LogicalResult
+bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
+ const BufferizationOptions &options) {
+ if (options.deallocationFn)
+ return (*options.deallocationFn)(b, loc, allocatedBuffer);
+
+ // Default buffer deallocation via DeallocOp.
+ b.create<memref::DeallocOp>(loc, allocatedBuffer);
+ return success();
+}
+
/// Move the insertion point of the given builder to the beginning of a
/// surrounding block as much as possible, while not crossing any allocation
/// hoisting barriers.
@@ -436,92 +465,39 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
return allocMemRefType;
}
-/// Create an AllocOp/DeallocOp pair, where the AllocOp is after
-/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
-/// bbArg) and the DeallocOp is at the end of the block.
-FailureOr<Value>
-bufferization::createAlloc(OpBuilder &b, Location loc, Value shapedValue,
- bool deallocMemref,
- const BufferizationOptions &options) {
+static Value createBufferAllocation(OpBuilder &b, Location loc, MemRefType type,
+ ValueRange dynShape) {
+ auto allocaOp = b.create<memref::AllocaOp>(loc, type, dynShape);
+ allocaOp->setAttr(kBufferAllocationAttr, b.getUnitAttr());
+ return allocaOp.getResult();
+}
+
+/// Create an allocation after `shapedValue.getDefiningOp` (or at the top of the
+/// block in case of a bbArg).
+FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
+ Value shapedValue) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
-
- // 1. Create memory allocation.
assert(shapedValue.getType().isa<ShapedType>());
MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
SmallVector<Value> dynShape;
// Note: getAllocationTypeAndShape also sets the insertion point.
MemRefType allocMemRefType =
getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
- FailureOr<Value> allocated =
- createAlloc(b, loc, allocMemRefType, dynShape, options);
- if (failed(allocated))
- return failure();
- Value casted = allocated.getValue();
+ Value alloc = createBufferAllocation(b, loc, allocMemRefType, dynShape);
if (memRefType && memRefType != allocMemRefType) {
- assert(memref::CastOp::areCastCompatible(allocated.getValue().getType(),
- memRefType) &&
+ assert(memref::CastOp::areCastCompatible(alloc.getType(), memRefType) &&
"createAlloc: cast incompatible");
- casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
- }
-
- if (deallocMemref) {
- // 2. Create memory deallocation.
- b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
- if (failed(createDealloc(b, loc, allocated.getValue(), options)))
- return failure();
- }
-
- return casted;
-}
-
-/// Create a memref allocation with the given type and dynamic extents.
-FailureOr<Value>
-bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
- ValueRange dynShape,
- const BufferizationOptions &options) {
- if (options.allocationFn)
- return (*options.allocationFn)(b, loc, type, dynShape,
- options.bufferAlignment);
-
- // Default bufferallocation via AllocOp.
- Value allocated = b.create<memref::AllocOp>(
- loc, type, dynShape, b.getI64IntegerAttr(options.bufferAlignment));
- return allocated;
-}
-
-/// Create a memref allocation with the given type and dynamic extents. May also
-/// deallocate the memref again.
-FailureOr<Value>
-bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
- ValueRange dynShape, bool deallocMemref,
- const BufferizationOptions &options) {
- OpBuilder::InsertionGuard g(b);
-
- FailureOr<Value> alloc = createAlloc(b, loc, type, dynShape, options);
- if (failed(alloc))
- return failure();
-
- if (deallocMemref) {
- // Dealloc at the end of the block.
- b.setInsertionPoint(alloc.getValue().getParentBlock()->getTerminator());
- if (failed(createDealloc(b, loc, *alloc, options)))
- return failure();
+ alloc = b.create<memref::CastOp>(loc, memRefType, alloc);
}
-
return alloc;
}
-/// Create a memref deallocation.
-LogicalResult
-bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
- const BufferizationOptions &options) {
- if (options.deallocationFn)
- return (*options.deallocationFn)(b, loc, allocatedBuffer);
-
- // Default buffer deallocation via DeallocOp.
- b.create<memref::DeallocOp>(loc, allocatedBuffer);
- return success();
+/// Create a memref allocation with the given type and dynamic extents.
+FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
+ MemRefType type,
+ ValueRange dynShape) {
+ return createBufferAllocation(b, loc, type, dynShape);
}
/// Create a memory copy between two memref buffers.
@@ -535,6 +511,41 @@ LogicalResult bufferization::createMemCpy(OpBuilder &b, Location loc,
return success();
}
+LogicalResult
+bufferization::finalizeBuffers(Operation *op,
+ const BufferizationOptions &options) {
+ IRRewriter rewriter(op->getContext());
+
+ // Bufferization creates memref.alloca ops. After bufferization, these must be
+ // rewritten to alloc/dealloc ops as specified in the bufferization options.
+ WalkResult status = op->walk([&](memref::AllocaOp allocaOp) {
+ // Ignore memref.alloca ops that were not created by the bufferization.
+ if (!allocaOp->hasAttr(kBufferAllocationAttr))
+ return WalkResult::skip();
+
+ Block *block = allocaOp->getBlock();
+ rewriter.setInsertionPoint(allocaOp);
+ FailureOr<Value> alloc =
+ createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(),
+ allocaOp.dynamicSizes(), options);
+ if (failed(alloc))
+ return WalkResult::interrupt();
+ rewriter.replaceOp(allocaOp, *alloc);
+
+ // Stop here if deallocations are deactivated.
+ if (!options.createDeallocs)
+ return WalkResult::advance();
+
+ rewriter.setInsertionPoint(block->getTerminator());
+ if (failed(createDealloc(rewriter, alloc->getLoc(), *alloc, options)))
+ return WalkResult::interrupt();
+
+ return WalkResult::advance();
+ });
+
+ return success(!status.wasInterrupted());
+}
+
//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index ba6cd37751119..f237cb7a6a70e 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -302,7 +302,11 @@ checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
LogicalResult bufferization::bufferizeOp(Operation *op,
const AnalysisState &analysisState) {
BufferizationState bufferizationState(analysisState);
- return bufferizeOp(op, bufferizationState);
+ if (failed(bufferizeOp(op, bufferizationState)))
+ return failure();
+ if (failed(finalizeBuffers(op, analysisState.getOptions())))
+ return failure();
+ return success();
}
LogicalResult
@@ -332,7 +336,10 @@ bufferization::bufferizeOp(Operation *op,
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
return failure();
- return checkBufferizationResult(op, bufferizationState.getOptions());
+ if (failed(checkBufferizationResult(op, bufferizationState.getOptions())))
+ return failure();
+
+ return success();
}
namespace {
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 85c425623eacc..333c874761049 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -1054,6 +1054,10 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
}
}
+ // Finalize all buffers.
+ if (failed(finalizeBuffers(moduleOp, options)))
+ return failure();
+
// Perform a post-processing pass of layout modification at function boundary
// according to the kBufferLayoutAttrName.
layoutPostProcessing(moduleOp);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 9153bce176fa0..2cbed7cebb1e9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -235,9 +235,8 @@ struct InitTensorOpInterface
if (initTensorOp->getUses().empty())
return success();
- FailureOr<Value> alloc =
- createAlloc(rewriter, initTensorOp->getLoc(), initTensorOp.result(),
- state.getOptions().createDeallocs, state.getOptions());
+ FailureOr<Value> alloc = state.createAlloc(rewriter, initTensorOp->getLoc(),
+ initTensorOp.result());
if (failed(alloc))
return failure();
replaceOpWithBufferizedValues(rewriter, op, *alloc);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index f068010a50895..67e28c46f3969 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -228,8 +228,7 @@ struct ExtractSliceOpInterface
Value alloc;
if (!inplace) {
FailureOr<Value> allocOrFailure =
- createAlloc(rewriter, loc, extractSliceOp.result(),
- state.getOptions().createDeallocs, state.getOptions());
+ state.createAlloc(rewriter, loc, extractSliceOp.result());
if (failed(allocOrFailure))
return failure();
alloc = *allocOrFailure;
@@ -338,9 +337,7 @@ struct FromElementsOpInterface
auto shape = tensorType.getShape();
MemRefType resultType = getContiguousMemRefType(tensorType);
FailureOr<Value> maybeBuffer =
- createAlloc(rewriter, loc, resultType, {},
- /*deallocMemref=*/state.getOptions().createDeallocs,
- state.getOptions());
+ state.createAlloc(rewriter, loc, resultType, {});
if (failed(maybeBuffer))
return failure();
Value buffer = *maybeBuffer;
@@ -389,10 +386,8 @@ struct GenerateOpInterface
Location loc = op->getLoc();
MemRefType memrefType =
getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
- FailureOr<Value> maybeResult =
- createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(),
- /*deallocMemref=*/state.getOptions().createDeallocs,
- state.getOptions());
+ FailureOr<Value> maybeResult = state.createAlloc(
+ rewriter, loc, memrefType, generateOp.dynamicExtents());
if (failed(maybeResult))
return failure();
Value result = *maybeResult;
More information about the Mlir-commits
mailing list