[Mlir-commits] [mlir] b2499bf - [mlir][bufferize][NFC] Refactor createAlloc function signature
Matthias Springer
llvmlistbot at llvm.org
Mon Jan 24 03:25:59 PST 2022
Author: Matthias Springer
Date: 2022-01-24T20:25:35+09:00
New Revision: b2499bf3e851c67ef623766b922de520de9235d5
URL: https://github.com/llvm/llvm-project/commit/b2499bf3e851c67ef623766b922de520de9235d5
DIFF: https://github.com/llvm/llvm-project/commit/b2499bf3e851c67ef623766b922de520de9235d5.diff
LOG: [mlir][bufferize][NFC] Refactor createAlloc function signature
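Pass a ValueRange instead of an ArrayRef<Value>, so that callers can pass any range of values (e.g., operand or result ranges) directly. Also provide an additional function overload that takes a `deallocMemref` flag and, when it is set, automatically inserts a deallocation of the buffer.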
Differential Revision: https://reviews.llvm.org/D118025
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index f679a22fa7a6c..bbac6e59aeeb2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -36,8 +36,8 @@ class BufferizationState;
/// Options for ComprehensiveBufferize.
struct BufferizationOptions {
- using AllocationFn = std::function<FailureOr<Value>(
- OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
+ using AllocationFn = std::function<FailureOr<Value>(OpBuilder &, Location,
+ MemRefType, ValueRange)>;
using DeallocationFn =
std::function<LogicalResult(OpBuilder &, Location, Value)>;
using MemCpyFn =
@@ -298,15 +298,23 @@ UnrankedMemRefType getUnrankedMemRefType(Type elementType,
MemRefType getDynamicMemRefType(RankedTensorType tensorType,
unsigned addressSpace = 0);
-/// Creates a memref allocation.
+/// Creates a memref allocation with the given type and dynamic extents.
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
- ArrayRef<Value> dynShape,
+ ValueRange dynShape,
+ const BufferizationOptions &options);
+
+/// Creates a memref allocation with the given type and dynamic extents. If
+/// `deallocMemref` is set, a deallocation op is inserted at the point where
+/// the allocation goes out of scope.
+FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
+ ValueRange dynShape, bool deallocMemref,
const BufferizationOptions &options);
/// Creates a memref allocation for the given shaped value. This function may
/// perform additional optimizations such as buffer allocation hoisting. If
/// `createDealloc`, a deallocation op is inserted at the point where the
/// allocation goes out of scope.
+// TODO: Allocation hoisting should be a cleanup pass.
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
bool deallocMemref,
const BufferizationOptions &options);
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index fb081d3d6c3cd..e565f41a39d5a 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -433,10 +433,10 @@ bufferization::createAlloc(OpBuilder &b, Location loc, Value shapedValue,
return casted;
}
-/// Create a memref allocation.
+/// Create a memref allocation with the given type and dynamic extents.
FailureOr<Value>
bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
- ArrayRef<Value> dynShape,
+ ValueRange dynShape,
const BufferizationOptions &options) {
if (options.allocationFn)
return (*options.allocationFn)(b, loc, type, dynShape);
@@ -447,6 +447,28 @@ bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
return allocated;
}
+/// Create a memref allocation with the given type and dynamic extents. If
+/// `deallocMemref`, also dealloc it before the enclosing block's terminator.
+FailureOr<Value>
+bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
+ ValueRange dynShape, bool deallocMemref,
+ const BufferizationOptions &options) {
+ OpBuilder::InsertionGuard g(b);
+
+ FailureOr<Value> alloc = createAlloc(b, loc, type, dynShape, options);
+ if (failed(alloc))
+ return failure();
+
+ if (deallocMemref) {
+ // Dealloc at the end of the block.
+ b.setInsertionPoint(alloc.getValue().getParentBlock()->getTerminator());
+ if (failed(createDealloc(b, loc, *alloc, options)))
+ return failure();
+ }
+
+ return alloc;
+}
+
/// Create a memref deallocation.
LogicalResult
bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 3c8b9c9606952..9409492e12dba 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -73,7 +73,7 @@ static void applyEnablingTransformations(ModuleOp moduleOp) {
static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
MemRefType type,
- ArrayRef<Value> dynShape) {
+ ValueRange dynShape) {
Value allocated = b.create<memref::AllocaOp>(
loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
return allocated;
More information about the Mlir-commits
mailing list