[Mlir-commits] [mlir] 7c3a810 - [mlir][linalg][bufferize] Put buffer copying in separate function
Matthias Springer
llvmlistbot at llvm.org
Tue Oct 12 17:29:33 PDT 2021
Author: Matthias Springer
Date: 2021-10-13T09:24:56+09:00
New Revision: 7c3a8108b303c4154fb958878256f7c4973b238f
URL: https://github.com/llvm/llvm-project/commit/7c3a8108b303c4154fb958878256f7c4973b238f
DIFF: https://github.com/llvm/llvm-project/commit/7c3a8108b303c4154fb958878256f7c4973b238f.diff
LOG: [mlir][linalg][bufferize] Put buffer copying in separate function
This is to avoid code duplication.
Differential Revision: https://reviews.llvm.org/D110940
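In short, the patch factors the repeated "allocate a new buffer and copy, unless bufferizing in-place" logic out of the individual bufferize() overloads into a single getResultBuffer() helper. A condensed sketch of the caller-side change (simplified from the diff below; locations, assertions, and insertion-point handling omitted):

  // Before: each bufferize() overload open-coded the same pattern.
  Value buffer;
  if (getInPlace(result) != InPlaceSpec::True) {
    buffer = createNewAllocDeallocPairForShapedValue(b, loc, operand, aliasInfo);
    if (!isInitTensorOp(operand))
      b.create<CopyOp>(loc, lookup(bvm, operand), buffer);
  } else {
    buffer = lookup(bvm, operand);
  }

  // After: a single call per tensor result.
  Value buffer = getResultBuffer(b, result, bvm, aliasInfo);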
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 7cb1b60de3d6..a003ed00e68e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -1405,6 +1405,38 @@ createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
 // Bufferization as simple BlockAndValueMapping rewrites.
 //===----------------------------------------------------------------------===//
+/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
+/// a new buffer and copy over data from the existing buffer if out-of-place
+/// bufferization is necessary.
+static Value getResultBuffer(OpBuilder &b, OpResult result,
+                             const BlockAndValueMapping &bvm,
+                             BufferizationAliasInfo &aliasInfo) {
+  OpBuilder::InsertionGuard guard(b);
+  Operation *op = result.getOwner();
+  Optional<OpOperand *> maybeOperand = getAliasingOpOperand(result);
+  assert(maybeOperand && "corresponding OpOperand not found");
+  Value operand = (*maybeOperand)->get();
+  Value operandBuffer = lookup(bvm, operand);
+  assert(operandBuffer && "operand buffer not found");
+
+  // If bufferizing out-of-place, allocate a new buffer.
+  if (getInPlace(result) != InPlaceSpec::True) {
+    Location loc = op->getLoc();
+    // Allocate the result buffer.
+    Value resultBuffer =
+        createNewAllocDeallocPairForShapedValue(b, loc, operand, aliasInfo);
+    if (!isInitTensorOp(operand)) {
+      // Set insertion point now that potential alloc/dealloc are introduced.
+      b.setInsertionPoint(op);
+      b.create<CopyOp>(loc, operandBuffer, resultBuffer);
+    }
+    return resultBuffer;
+  }
+
+  // Bufferizing in-place. No need to allocate a new buffer.
+  return operandBuffer;
+}
+
 /// Helper function for LinalgOp bufferization.
 /// Examines each result and determines whether it bufferizes inplace on an
 /// operand.
@@ -1647,27 +1679,8 @@ static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(castOp);
-  // If castOp is not inPlace, allocate a new buffer.
-  auto inPlace = getInPlace(castOp->getResult(0));
-  Value newBuffer;
-  if (inPlace != InPlaceSpec::True) {
-    Location loc = castOp.getLoc();
-    // Alloc a copy for `writeOp.source()`, it will become the result buffer.
-    newBuffer = createNewAllocDeallocPairForShapedValue(b, loc, castOp.source(),
-                                                        aliasInfo);
-    if (!isInitTensorOp(castOp.source())) {
-      // Set insertion point now that potential alloc/dealloc are introduced.
-      b.setInsertionPoint(castOp);
-      b.create<CopyOp>(loc, lookup(bvm, castOp.source()), newBuffer);
-    }
-  } else {
-    // InPlace write will result in memref.tensor_load(x) which must
-    // canonicalize away with one of it uses.
-    newBuffer = lookup(bvm, castOp.source());
-    assert(newBuffer && "missing buffer");
-  }
-
-  Type sourceType = newBuffer.getType();
+  Value resultBuffer = getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo);
+  Type sourceType = resultBuffer.getType();
   auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
   auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
   assert(rankedMemRefType || unrankedMemRefType);
@@ -1681,7 +1694,8 @@ static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
                              : ArrayRef<AffineMap>{};
   Type memRefType = getContiguousOrUnrankedMemRefType(
       castOp.getResult().getType(), affineMaps, memorySpace);
-  Value res = b.create<memref::CastOp>(castOp.getLoc(), memRefType, newBuffer);
+  Value res =
+      b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
   aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
   map(bvm, castOp.getResult(), res);
   return success();
@@ -1731,9 +1745,6 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
-  // If inPlace, just forward the buffer.
-  // Otherwise alloc and copy.
-  Location loc = forOp.getLoc();
   for (OpResult opResult : forOp->getResults()) {
     if (!opResult.getType().isa<TensorType>())
       continue;
@@ -1741,29 +1752,11 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
     // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
     assert(opResult.getType().isa<RankedTensorType>() &&
            "unsupported unranked tensor");
+
+    // TODO: More general: Matching bbArg does not bufferize to a read.
+    Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
+
     OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
-    Value operand = opOperand.get();
-    Value operandBuffer = lookup(bvm, operand);
-    Value resultBuffer = operandBuffer;
-    if (getInPlace(opResult) != InPlaceSpec::True) {
-      resultBuffer =
-          createNewAllocDeallocPairForShapedValue(b, loc, operand, aliasInfo);
-      // If the tensor comes from either:
-      //   - linalg.init_tensor
-      //   - tensor.cast(linalg.init_tensor())
-      // Then the value is unitialized and we do not need to copy. This is a
-      // pragmatic simplification of "matching bbArg does not bufferize to a
-      // read".
-      // TODO: "matching bbArg does not bufferize to a read" is a more general
-      // check.
-      if (!isInitTensorOp(operand)) {
-        OpBuilder::InsertionGuard g(b);
-        // Set insertion point now that potential alloc/dealloc are introduced.
-        // Copy is inserted just before the forOp.
-        b.setInsertionPoint(forOp);
-        b.create<linalg::CopyOp>(forOp.getLoc(), operandBuffer, resultBuffer);
-      }
-    }
     BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
     aliasInfo.createAliasInfoEntry(resultBuffer);
     aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
@@ -1880,40 +1873,19 @@ static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
     assert(oldOutputTensor.getType().isa<RankedTensorType>() &&
            "bufferizable output must be a ranked tensor");
-    Value outputBuffer = lookup(bvm, oldOutputTensor);
     const OpResult &opResult = tiledLoopOp->getResult(resultIndex);
     OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex);
-    // If the result is not inplaceable, need to allocate a copy for it.
-    if (getInPlace(opResult) != InPlaceSpec::True) {
-      auto loc = tiledLoopOp.getLoc();
-      Value alloc = createNewAllocDeallocPairForShapedValue(
-          b, loc, oldOutputTensor, aliasInfo);
-      // If the tensor comes from either:
-      //   - linalg.init_tensor
-      //   - tensor.cast(linalg.init_tensor())
-      // Then the value is unitialized and we do not need to copy. This is a
-      // pragmatic simplification of "matching bbArg does not bufferize to a
-      // read".
-      // TODO: "matching bbArg does not bufferize to a read" is a more general
-      // check.
-      if (!isInitTensorOp(oldOutputTensor)) {
-        OpBuilder::InsertionGuard g(b);
-        // Set insertion point now that potential alloc/dealloc are introduced.
-        // Copy is inserted just before the tiledLoopOp.
-        b.setInsertionPoint(tiledLoopOp);
-        b.create<linalg::CopyOp>(loc, outputBuffer, alloc);
-      }
-      outputBuffer = alloc;
-    }
+    Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
+
     // Insert mapping and aliasing info.
-    aliasInfo.createAliasInfoEntry(outputBuffer);
-    aliasInfo.insertNewBufferEquivalence(opResult, outputBuffer);
-    map(bvm, opResult, outputBuffer);
+    aliasInfo.createAliasInfoEntry(resultBuffer);
+    aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer);
+    map(bvm, opResult, resultBuffer);
     // Insert new operand and bbArg.
-    tiledLoopOp->insertOperands(nextOutputOperandIndex, outputBuffer);
+    tiledLoopOp->insertOperands(nextOutputOperandIndex, resultBuffer);
     BlockArgument newBufferBBArg =
-        body->insertArgument(nextOutputBBArgIndex, outputBuffer.getType());
+        body->insertArgument(nextOutputBBArgIndex, resultBuffer.getType());
     BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex);
     // Insert mapping and aliasing info.
     aliasInfo.createAliasInfoEntry(newBufferBBArg);
@@ -2043,25 +2015,15 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
LDBG("bufferize: " << *insertSliceOp << '\n');
Location loc = insertSliceOp.getLoc();
- Value dstMemref = lookup(bvm, insertSliceOp.dest());
- if (!dstMemref)
- return failure();
- auto inPlace = getInPlace(insertSliceOp->getResult(0));
- if (inPlace != InPlaceSpec::True) {
- // Since insert_slice arise from tiling and introducing loops, this
- // case is generally a deal breaker. When used with loops, this ends up
- // cloning the whole tensor on every single iteration and is a symptom
- // of a catastrophically bad scheduling decision.
- // TODO: be very loud about it or even consider failing the pass.
- // Alloc a copy for `insertSliceOp.dest()`, it will become the result
- // buffer.
- Value newDstMemref = createNewAllocDeallocPairForShapedValue(
- b, loc, insertSliceOp.dest(), aliasInfo);
- // Set insertion point now that potential alloc/dealloc are introduced.
- b.setInsertionPoint(insertSliceOp);
- b.create<CopyOp>(insertSliceOp.getLoc(), dstMemref, newDstMemref);
- dstMemref = newDstMemref;
- }
+ // Since insert_slice arise from tiling and introducing loops, this
+ // case is generally a deal breaker. When used with loops, this ends up
+ // cloning the whole tensor on every single iteration and is a symptom
+ // of a catastrophically bad scheduling decision.
+ // TODO: be very loud about it or even consider failing the pass.
+ // Alloc a copy for `insertSliceOp.dest()`, it will become the result
+ // buffer.
+ Value dstMemref =
+ getResultBuffer(b, insertSliceOp->getResult(0), bvm, aliasInfo);
auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
Value srcMemref = lookup(bvm, insertSliceOp.source());
@@ -2079,6 +2041,7 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
   //   slice is computed out of place into the inplace full tensor.
   // - The result is not inplace. This is the case where the whole tensor is
   //   cloned and the clone needs to be updated.
+  auto inPlace = getInPlace(insertSliceOp->getResult(0));
   if (!aliasInfo.isSourceEquivalentToAMatchingInplaceExtractSliceOp(
           insertSliceOp) ||
       inPlace != InPlaceSpec::True) {
@@ -2117,38 +2080,16 @@ static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
     return success();
   }
-  auto inPlace = getInPlace(op->getResult(0));
-  auto writeOp = cast<vector::TransferWriteOp>(op.getOperation());
-
-  // If transfer_write is not inPlace, allocate a new buffer.
-  Value newInputBuffer;
-  Location loc = op.getLoc();
-  if (inPlace != InPlaceSpec::True) {
-    // Alloc a copy for `writeOp.source()`, it will become the result buffer.
-    newInputBuffer = createNewAllocDeallocPairForShapedValue(
-        b, loc, writeOp.source(), aliasInfo);
-    Value v = lookup(bvm, writeOp.source());
-    if (!isInitTensorOp(writeOp.source())) {
-      // Set insertion point now that potential alloc/dealloc are introduced.
-      b.setInsertionPoint(op);
-      b.create<CopyOp>(loc, v, newInputBuffer);
-    }
-  } else {
-    // InPlace write will result in memref.tensor_load(x) which must
-    // canonicalize away with one of it uses.
-    newInputBuffer = lookup(bvm, writeOp.source());
-    assert(newInputBuffer && "missing buffer");
-  }
-
   // Create a new transfer_write on buffer that doesn't have a return value.
   // Leave the previous transfer_write to dead code as it still has uses at
   // this point.
+  auto writeOp = cast<vector::TransferWriteOp>(op.getOperation());
+  Value resultBuffer = getResultBuffer(b, op->getResult(0), bvm, aliasInfo);
   b.create<vector::TransferWriteOp>(
-      loc, writeOp.vector(), newInputBuffer, writeOp.indices(),
+      op.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
       writeOp.permutation_map(),
       writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr());
-
-  map(bvm, op->getResult(0), newInputBuffer);
+  map(bvm, op->getResult(0), resultBuffer);
   return success();
 }
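For reference, a bufferization pattern written against the new helper reduces to the following shape. This is a minimal sketch, not part of the patch: SomeTensorOp is a hypothetical single-result tensor op, and the surrounding ComprehensiveBufferize utilities (lookup, map, getResultBuffer) are assumed as shown in the diff above.

  // Hypothetical usage sketch (not in this commit).
  static LogicalResult bufferize(OpBuilder &b, SomeTensorOp op,
                                 BlockAndValueMapping &bvm,
                                 BufferizationAliasInfo &aliasInfo) {
    // Take a guard before anything else.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(op);
    // Allocates a new buffer and copies, or forwards the operand's buffer
    // when the result bufferizes in-place.
    Value resultBuffer = getResultBuffer(b, op->getResult(0), bvm, aliasInfo);
    // ... create the memref-based replacement for `op` using resultBuffer ...
    map(bvm, op->getResult(0), resultBuffer);
    return success();
  }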