[Mlir-commits] [mlir] 698896c - [mlir][linalg][bufferize][NFC] Change allocationFn return type to FailureOr<Value>
Matthias Springer
llvmlistbot at llvm.org
Thu Jan 6 13:49:40 PST 2022
Author: Matthias Springer
Date: 2022-01-07T06:33:19+09:00
New Revision: 698896cd6c8cc5e865e1715e7c9d82295f82745b
URL: https://github.com/llvm/llvm-project/commit/698896cd6c8cc5e865e1715e7c9d82295f82745b
DIFF: https://github.com/llvm/llvm-project/commit/698896cd6c8cc5e865e1715e7c9d82295f82745b.diff
LOG: [mlir][linalg][bufferize][NFC] Change allocationFn return type to FailureOr<Value>
In addition, all functions that call `allocationFn` now return `FailureOr<Value>`, so allocation failures are propagated to callers instead of being asserted on. This resolves a few TODOs in the codebase.
Differential Revision: https://reviews.llvm.org/D116452
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 921353a23ea7..c18f7f9fc5e9 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -41,7 +41,7 @@ struct PostAnalysisStep;
// TODO: Could be replaced with a "bufferization strategy" object with virtual
// functions in the future.
struct AllocationCallbacks {
- using AllocationFn = std::function<Optional<Value>(
+ using AllocationFn = std::function<FailureOr<Value>(
OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
@@ -360,15 +360,15 @@ class BufferizationState {
Value findLastPrecedingWrite(Value value) const;
/// Creates a memref allocation.
- Optional<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
- ArrayRef<Value> dynShape) const;
+ FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
+ ArrayRef<Value> dynShape) const;
/// Creates a memref allocation for the given shaped value. This function may
/// perform additional optimizations such as buffer allocation hoisting. If
/// `createDealloc`, a deallocation op is inserted at the point where the
/// allocation goes out of scope.
- Value createAlloc(OpBuilder &b, Location loc, Value shapedValue,
- bool deallocMemref) const;
+ FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
+ bool deallocMemref) const;
/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 118e25a23148..b2a58069e85a 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -41,9 +41,9 @@ using namespace linalg::comprehensive_bufferize;
/// Default allocation function that is used by the comprehensive bufferization
/// pass. The default currently creates a ranked memref using `memref.alloc`.
-static Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc,
- MemRefType type,
- ArrayRef<Value> dynShape) {
+static FailureOr<Value> defaultAllocationFn(OpBuilder &b, Location loc,
+ MemRefType type,
+ ArrayRef<Value> dynShape) {
Value allocated = b.create<memref::AllocOp>(
loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
return allocated;
@@ -391,8 +391,10 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getResultBuffer(
// allocation should be inserted (in the absence of allocation hoisting).
setInsertionPointAfter(rewriter, operandBuffer);
// Allocate the result buffer.
- Value resultBuffer =
+ FailureOr<Value> resultBuffer =
createAlloc(rewriter, loc, operandBuffer, options.createDeallocs);
+ if (failed(resultBuffer))
+ return failure();
bool skipCopy = false;
// Do not copy if the last preceding write of `operand` is an op that does
// not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@@ -413,7 +415,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getResultBuffer(
if (!skipCopy) {
// The copy happens right before the op that is bufferized.
rewriter.setInsertionPoint(op);
- createMemCpy(rewriter, loc, operandBuffer, resultBuffer);
+ createMemCpy(rewriter, loc, operandBuffer, *resultBuffer);
}
return resultBuffer;
}
@@ -537,7 +539,8 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
/// Create an AllocOp/DeallocOp pair, where the AllocOp is after
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
/// bbArg) and the DeallocOp is at the end of the block.
-Value mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
+FailureOr<Value>
+mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
OpBuilder &b, Location loc, Value shapedValue, bool deallocMemref) const {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -549,10 +552,9 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
// Note: getAllocationTypeAndShape also sets the insertion point.
MemRefType allocMemRefType =
getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
- Optional<Value> allocated = createAlloc(b, loc, allocMemRefType, dynShape);
- // TODO: For now just assert the value is returned. Eventually need to
- // error-propagate.
- assert(allocated && "allocation failed");
+ FailureOr<Value> allocated = createAlloc(b, loc, allocMemRefType, dynShape);
+ if (failed(allocated))
+ return failure();
Value casted = allocated.getValue();
if (memRefType && memRefType != allocMemRefType) {
casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
@@ -568,7 +570,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
}
/// Create a memref allocation.
-Optional<Value>
+FailureOr<Value>
mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
OpBuilder &b, Location loc, MemRefType type,
ArrayRef<Value> dynShape) const {
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index c4f42afb9828..a3cb3c36065b 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -55,6 +55,8 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
OpResult opResult = op.getTiedOpResult(opOperand);
assert(opResult && "could not find correspond OpResult");
FailureOr<Value> resultBuffer = state.getResultBuffer(rewriter, opResult);
+ if (failed(resultBuffer))
+ return failure();
newOutputBuffers.push_back(*resultBuffer);
}
@@ -210,10 +212,12 @@ struct InitTensorOpInterface
if (initTensorOp->getUses().empty())
return success();
- Value alloc = state.createAlloc(rewriter, initTensorOp->getLoc(),
- initTensorOp.result(),
- state.getOptions().createDeallocs);
- replaceOpWithBufferizedValues(rewriter, op, alloc);
+ FailureOr<Value> alloc = state.createAlloc(
+ rewriter, initTensorOp->getLoc(), initTensorOp.result(),
+ state.getOptions().createDeallocs);
+ if (failed(alloc))
+ return failure();
+ replaceOpWithBufferizedValues(rewriter, op, *alloc);
return success();
}
};
@@ -287,6 +291,8 @@ struct TiledLoopOpInterface
if (value.getType().isa<TensorType>()) {
FailureOr<Value> buffer = state.getResultBuffer(
rewriter, tiledLoopOp->getResult(nextResultNum++));
+ if (failed(buffer))
+ return failure();
newOutputs.push_back(*buffer);
newResults.push_back(*buffer);
} else {
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index 5983d421aaed..1d62c7880a31 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -295,10 +295,19 @@ struct ForOpInterface
};
// Construct a new scf.for op with memref instead of tensor values.
+ bool resultBufferFailure = false;
SmallVector<Value> initArgs =
convert(forOp.getInitArgs(), [&](Value val, int64_t index) {
- return *state.getResultBuffer(rewriter, forOp->getOpResult(index));
+ FailureOr<Value> resultBuffer =
+ state.getResultBuffer(rewriter, forOp->getOpResult(index));
+ if (failed(resultBuffer)) {
+ resultBufferFailure = true;
+ return Value();
+ }
+ return *resultBuffer;
});
+ if (resultBufferFailure)
+ return failure();
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), initArgs);
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 6b8b8983972a..b6ee0fc63471 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -54,6 +54,8 @@ struct CastOpInterface
// The result buffer still has the old (pre-cast) type.
FailureOr<Value> resultBuffer =
state.getResultBuffer(rewriter, castOp->getResult(0));
+ if (failed(resultBuffer))
+ return failure();
auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
Attribute memorySpace = sourceMemRefType.getMemorySpace();
TensorType resultTensorType =
@@ -149,9 +151,14 @@ struct ExtractSliceOpInterface
// If not inplaceable, alloc.
bool inplace = state.isInPlace(extractSliceOp->getResult(0));
Value alloc;
- if (!inplace)
- alloc = state.createAlloc(rewriter, loc, extractSliceOp.result(),
- state.getOptions().createDeallocs);
+ if (!inplace) {
+ FailureOr<Value> allocOrFailure =
+ state.createAlloc(rewriter, loc, extractSliceOp.result(),
+ state.getOptions().createDeallocs);
+ if (failed(allocOrFailure))
+ return failure();
+ alloc = *allocOrFailure;
+ }
// Bufferize to subview.
auto subviewMemRefType =
@@ -238,6 +245,8 @@ struct InsertOpInterface
auto insertOp = cast<tensor::InsertOp>(op);
FailureOr<Value> destMemref =
state.getResultBuffer(rewriter, insertOp->getOpResult(0));
+ if (failed(destMemref))
+ return failure();
rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
*destMemref, insertOp.indices());
replaceOpWithBufferizedValues(rewriter, op, *destMemref);
@@ -404,6 +413,8 @@ struct InsertSliceOpInterface
// When bufferizing out-of-place, `getResultBuffer` allocates.
FailureOr<Value> dstMemref =
state.getResultBuffer(rewriter, insertSliceOp->getResult(0));
+ if (failed(dstMemref))
+ return failure();
// Take a subview of the dst.
auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index 3c8d6a9c96e5..58013323cb70 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -100,6 +100,8 @@ struct TransferWriteOpInterface
// this point.
FailureOr<Value> resultBuffer =
state.getResultBuffer(rewriter, op->getResult(0));
+ if (failed(resultBuffer))
+ return failure();
rewriter.create<vector::TransferWriteOp>(
writeOp.getLoc(), writeOp.vector(), *resultBuffer, writeOp.indices(),
writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 13e18001d82e..21d7c4e62a45 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -64,9 +64,9 @@ static void applyEnablingTransformations(ModuleOp moduleOp) {
(void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
}
-static Optional<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
- MemRefType type,
- ArrayRef<Value> dynShape) {
+static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
+ MemRefType type,
+ ArrayRef<Value> dynShape) {
Value allocated = b.create<memref::AllocaOp>(
loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
return allocated;
More information about the Mlir-commits mailing list