[Mlir-commits] [mlir] dd65f42 - [mlir][Linalg] NFC - More gracefully degrade lookup into failure during comprehensive bufferization (4/n)
Nicolas Vasilache
llvmlistbot at llvm.org
Fri May 14 15:12:52 PDT 2021
Author: Nicolas Vasilache
Date: 2021-05-14T22:12:23Z
New Revision: dd65f420cd2b983ea1e71ed685c811f00110bafb
URL: https://github.com/llvm/llvm-project/commit/dd65f420cd2b983ea1e71ed685c811f00110bafb
DIFF: https://github.com/llvm/llvm-project/commit/dd65f420cd2b983ea1e71ed685c811f00110bafb.diff
LOG: [mlir][Linalg] NFC - More gracefully degrade lookup into failure during comprehensive bufferization (4/n)
Differential Revision: https://reviews.llvm.org/D102420
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 0f271c2f27983..4ad8095505025 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -285,7 +285,7 @@ static Value lookup(BlockAndValueMapping &bvm, Value key) {
key.getDefiningOp()->getParentOfType<FuncOp>()->dump();
}
llvm::errs() << "NO VALUE FOR KEY: " << key << "\n";
- abort();
+ return Value();
}
return bvm.lookup(key);
}
@@ -595,9 +595,10 @@ static void destructiveUpdateAnalysis(Block *block,
/// the Linalg op. If the tensor is an "init" tensor (i.e. its value is
/// actually used in the payload region), we additionally copy the original
/// value into the newly allocated buffer.
-static void allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
- SmallVectorImpl<Value> &resultBuffers,
- BlockAndValueMapping &bvm) {
+static LogicalResult
+allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
+ SmallVectorImpl<Value> &resultBuffers,
+ BlockAndValueMapping &bvm) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -618,7 +619,10 @@ static void allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
// results.
OpResult tiedResult = getMatchingOpResult(op, opOperand);
if (getInPlace(tiedResult) == InPlaceSpec::True) {
- resultBuffers.push_back(lookup(bvm, output));
+ Value v = lookup(bvm, output);
+ if (!v)
+ return failure();
+ resultBuffers.push_back(v);
continue;
}
@@ -628,11 +632,17 @@ static void allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
resultBuffers.push_back(alloc);
// Additionally, if the output buffer is used, clone its value for now.
- if (op.payloadUsesValueFromOpOperand(&opOperand))
- b.create<CopyOp>(loc, lookup(bvm, output), alloc);
+ if (op.payloadUsesValueFromOpOperand(&opOperand)) {
+ Value v = lookup(bvm, output);
+ if (!v)
+ return failure();
+ b.create<CopyOp>(loc, v, alloc);
+ }
}
if (op->getNumResults())
map(bvm, op->getResults(), resultBuffers);
+
+ return success();
}
static void finalizeBufferAllocation(OpBuilder &b, LinalgOp op,
@@ -662,16 +672,21 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
if (op.hasBufferSemantics())
return failure();
- LLVM_DEBUG(DBGS() << "convert: " << *op << "\n");
+ LLVM_DEBUG(DBGS() << "bufferize: " << *op << "\n");
b.setInsertionPoint(op);
Location loc = op.getLoc();
SmallVector<Value, 2> newInputBuffers;
newInputBuffers.reserve(op.getNumInputs());
- for (Value v : op.getInputs())
- newInputBuffers.push_back(lookup(bvm, v));
+ for (Value in : op.getInputs()) {
+ Value v = lookup(bvm, in);
+ if (!v)
+ return failure();
+ newInputBuffers.push_back(v);
+ }
SmallVector<Value, 2> newOutputBuffers;
- allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm);
+ if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm)))
+ return failure();
finalizeBufferAllocation(b, op, newInputBuffers, newOutputBuffers, bvm);
return success();
}
@@ -680,8 +695,12 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
/// behind that will get DCE'd.
static LogicalResult bufferize(OpBuilder &b, memref::DimOp dimOp,
BlockAndValueMapping &bvm) {
- if (dimOp.memrefOrTensor().getType().isa<RankedTensorType>())
- dimOp.memrefOrTensorMutable().assign(lookup(bvm, dimOp.memrefOrTensor()));
+ if (dimOp.memrefOrTensor().getType().isa<RankedTensorType>()) {
+ Value v = lookup(bvm, dimOp.memrefOrTensor());
+ if (!v)
+ return failure();
+ dimOp.memrefOrTensorMutable().assign(v);
+ }
return success();
}
@@ -721,8 +740,10 @@ static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
auto tensorType = operand.get().getType().dyn_cast<TensorType>();
if (!tensorType)
continue;
- operand.set(b.create<memref::TensorLoadOp>(returnOp.getLoc(),
- lookup(bvm, operand.get())));
+ Value v = lookup(bvm, operand.get());
+ if (!v)
+ return failure();
+ operand.set(b.create<memref::TensorLoadOp>(returnOp.getLoc(), v));
}
return success();
}
@@ -738,6 +759,8 @@ static LogicalResult bufferize(OpBuilder &b,
Location loc = subTensorInsertOp.getLoc();
Value dstMemref = lookup(bvm, subTensorInsertOp.dest());
+ if (!dstMemref)
+ return failure();
auto inPlace = getInPlace(subTensorInsertOp->getResult(0));
if (inPlace != InPlaceSpec::True) {
// Since subtensor_insert arise from tiling and introducing loops, this case
@@ -754,6 +777,8 @@ static LogicalResult bufferize(OpBuilder &b,
auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
Value srcMemref = lookup(bvm, subTensorInsertOp.source());
+ if (!srcMemref)
+ return failure();
auto subviewMemRefType =
memref::SubViewOp::inferRankReducedResultType(
subTensorInsertOp.getSourceType().getRank(), dstMemrefType,
@@ -798,11 +823,14 @@ static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
if (op.getShapedType().isa<MemRefType>())
return failure();
- LLVM_DEBUG(DBGS() << "convert: " << *op << "\n");
+ LLVM_DEBUG(DBGS() << "bufferize: " << *op << "\n");
/// transfer_read from buffer always reads from the bufferized op.source().
if (auto readOp = dyn_cast<vector::TransferReadOp>(op.getOperation())) {
- readOp.sourceMutable().assign(lookup(bvm, op.source()));
+ Value v = lookup(bvm, op.source());
+ if (!v)
+ return failure();
+ readOp.sourceMutable().assign(v);
return success();
}
@@ -820,6 +848,8 @@ static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
// InPlace write will result in memref.tensor_load(x) which must
// canonicalize away with one of it uses.
newInputBuffer = lookup(bvm, writeOp.source());
+ if (!newInputBuffer)
+ return failure();
}
// Create a new transfer_write on buffer that doesn't have a return value.
More information about the Mlir-commits
mailing list