[Mlir-commits] [mlir] 65ef43e - [mlir][linalg][bufferize][NFC] Check return value of getResultBuffer
Matthias Springer
llvmlistbot at llvm.org
Thu Oct 21 01:27:56 PDT 2021
Author: Matthias Springer
Date: 2021-10-21T17:24:17+09:00
New Revision: 65ef43e288ad1e9fa7a01d2c09a13727f568b870
URL: https://github.com/llvm/llvm-project/commit/65ef43e288ad1e9fa7a01d2c09a13727f568b870
DIFF: https://github.com/llvm/llvm-project/commit/65ef43e288ad1e9fa7a01d2c09a13727f568b870.diff
LOG: [mlir][linalg][bufferize][NFC] Check return value of getResultBuffer
In a subsequent commit, getResultBuffer may return a "null" Value. This happens when the buffer returned from an scf.if is not unique.
This commit prepares for scf.if support and keeps the next commit smaller.
Differential Revision: https://reviews.llvm.org/D111927
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index e0e4105870b4..be18cd135a47 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -1422,10 +1422,11 @@ static Value getResultBuffer(OpBuilder &b, OpResult result,
/// Helper function for LinalgOp bufferization.
/// When allocating a new buffer, analyze whether `op` wants to read form that
/// buffer. Only in that case, a copy of the result buffer may be needed.
-static void allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
- SmallVectorImpl<Value> &resultBuffers,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo) {
+static LogicalResult
+allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
+ SmallVectorImpl<Value> &resultBuffers,
+ BlockAndValueMapping &bvm,
+ BufferizationAliasInfo &aliasInfo) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
@@ -1437,11 +1438,15 @@ static void allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
assert(opResult && "could not find correspond OpResult");
bool skipCopy = !op.payloadUsesValueFromOperand(opOperand);
Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo, skipCopy);
+ if (!resultBuffer)
+ return failure();
resultBuffers.push_back(resultBuffer);
}
if (op->getNumResults())
map(bvm, op->getResults(), resultBuffers);
+
+ return success();
}
/// Generic conversion for any LinalgOp on tensors.
@@ -1469,7 +1474,9 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
}
SmallVector<Value> newOutputBuffers;
// Try to allocate new buffers depending on op's inplace semantics.
- allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm, aliasInfo);
+ if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm,
+ aliasInfo)))
+ return failure();
// Clone the newly bufferized op.
SmallVector<Value> newOperands = newInputBuffers;
@@ -1638,6 +1645,8 @@ static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
b.setInsertionPoint(castOp);
Value resultBuffer = getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo);
+ if (!resultBuffer)
+ return failure();
Type sourceType = resultBuffer.getType();
auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
@@ -1713,6 +1722,8 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
// TODO: More general: Matching bbArg does not bufferize to a read.
Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
+ if (!resultBuffer)
+ return failure();
OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
@@ -1834,6 +1845,8 @@ static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
const OpResult &opResult = tiledLoopOp->getResult(resultIndex);
OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex);
Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
+ if (!resultBuffer)
+ return failure();
// Insert mapping and aliasing info.
aliasInfo.createAliasInfoEntry(resultBuffer);
@@ -1982,6 +1995,9 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
// buffer.
Value dstMemref =
getResultBuffer(b, insertSliceOp->getResult(0), bvm, aliasInfo);
+ if (!dstMemref)
+ return failure();
+
auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
Value srcMemref = lookup(bvm, insertSliceOp.source());
@@ -2044,6 +2060,8 @@ static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
// this point.
auto writeOp = cast<vector::TransferWriteOp>(op.getOperation());
Value resultBuffer = getResultBuffer(b, op->getResult(0), bvm, aliasInfo);
+ if (!resultBuffer)
+ return failure();
b.create<vector::TransferWriteOp>(
op.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
writeOp.permutation_map(),
More information about the Mlir-commits mailing list