[Mlir-commits] [mlir] 17194ca - [mlir][linalg][bufferize][NFC] Clean up tensor op bufferization
Matthias Springer
llvmlistbot at llvm.org
Mon Nov 15 18:18:13 PST 2021
Author: Matthias Springer
Date: 2021-11-16T11:17:42+09:00
New Revision: 17194ca96ab5a46ee90c656f7654d3f15f5d46c6
URL: https://github.com/llvm/llvm-project/commit/17194ca96ab5a46ee90c656f7654d3f15f5d46c6
DIFF: https://github.com/llvm/llvm-project/commit/17194ca96ab5a46ee90c656f7654d3f15f5d46c6.diff
LOG: [mlir][linalg][bufferize][NFC] Clean up tensor op bufferization
Differential Revision: https://reviews.llvm.org/D113730
Added:
Modified:
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 6482d740f1515..697b894f89908 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -2244,23 +2244,22 @@ struct ExtractSliceOpInterface
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+ LDBG("bufferize: " << *extractSliceOp << '\n');
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(extractSliceOp);
- LDBG("bufferize: " << *extractSliceOp << '\n');
-
Location loc = extractSliceOp.getLoc();
- // Bail if source was not bufferized.
Value srcMemref = state.lookupBuffer(extractSliceOp.source());
auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
auto dstTensorType =
extractSliceOp.result().getType().cast<RankedTensorType>();
// If not inplaceable, alloc.
+ bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0));
Value alloc;
- if (!state.aliasInfo.isInPlace(extractSliceOp->getResult(0)))
+ if (!inplace)
alloc = createNewAllocDeallocPairForShapedValue(
b, loc, extractSliceOp.result(), state);
@@ -2278,7 +2277,7 @@ struct ExtractSliceOpInterface
state.aliasInfo.insertNewBufferAlias(subView, srcMemref);
/// If not inplaceable, copy.
- if (alloc) {
+ if (!inplace) {
// Do not copy if the copied data is never read.
if (isValueRead(extractSliceOp.result()))
state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView,
@@ -2374,34 +2373,23 @@ struct InsertSliceOpInterface
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
+ // insert_slice ops arise from tiling and bufferizing them out-of-place is
+ // generally a deal breaker. When used with loops, this ends up cloning the
+ // whole tensor on every single iteration and is a symptom of a
+ // catastrophically bad scheduling decision.
+ // TODO: be very loud about it or even consider failing the pass.
auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+ LDBG("bufferize: " << *insertSliceOp << '\n');
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(insertSliceOp);
-
- LDBG("bufferize: " << *insertSliceOp << '\n');
-
Location loc = insertSliceOp.getLoc();
- // Since insert_slice arise from tiling and introducing loops, this
- // case is generally a deal breaker. When used with loops, this ends up
- // cloning the whole tensor on every single iteration and is a symptom
- // of a catastrophically bad scheduling decision.
- // TODO: be very loud about it or even consider failing the pass.
- // Alloc a copy for `insertSliceOp.dest()`, it will become the result
- // buffer.
+
+ // When bufferizing out-of-place, `getResultBuffer` allocates.
Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state);
if (!dstMemref)
return failure();
- auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
-
- Value srcMemref = state.lookupBuffer(insertSliceOp.source());
- auto subviewMemRefType =
- memref::SubViewOp::inferRankReducedResultType(
- insertSliceOp.getSourceType().getRank(), dstMemrefType,
- insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
- insertSliceOp.getMixedStrides())
- .cast<MemRefType>();
// A copy of the source buffer is needed if either:
// - The producer of `source` is not inplace. This is the case where a
@@ -2409,23 +2397,32 @@ struct InsertSliceOpInterface
// - The result is not inplace. This is the case where the whole tensor is
// cloned and the clone needs to be updated.
// TODO: Is this necessary?
- if (!isSourceEquivalentToAMatchingInplaceExtractSliceOp(state.aliasInfo,
- insertSliceOp) ||
- !state.aliasInfo.isInPlace(insertSliceOp->getResult(0))) {
+ bool needCopy = !isSourceEquivalentToAMatchingInplaceExtractSliceOp(
+ state.aliasInfo, insertSliceOp) ||
+ !state.aliasInfo.isInPlace(insertSliceOp->getResult(0));
+ if (needCopy) {
LDBG("insert_slice needs extra source copy: " << insertSliceOp.source()
<< " -> copy\n");
// Take a subview of the dst.
+ auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
+ auto subviewMemRefType =
+ memref::SubViewOp::inferRankReducedResultType(
+ insertSliceOp.getSourceType().getRank(), dstMemrefType,
+ insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+ insertSliceOp.getMixedStrides())
+ .cast<MemRefType>();
Value subView = b.create<memref::SubViewOp>(
loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
// Insert new alias.
state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
+ // Copy tensor.
+ Value srcMemref = state.lookupBuffer(insertSliceOp.source());
state.allocationFns.memCpyFn(b, insertSliceOp.getLoc(), srcMemref,
subView);
}
state.mapBuffer(insertSliceOp.result(), dstMemref);
-
return success();
}
};
More information about the Mlir-commits
mailing list