[Mlir-commits] [mlir] 6db2007 - [mlir][linalg][bufferize][NFC] Use same OpBuilder throughout bufferization
Matthias Springer
llvmlistbot at llvm.org
Fri Dec 3 17:03:18 PST 2021
Author: Matthias Springer
Date: 2021-12-04T09:57:26+09:00
New Revision: 6db200736c51a61834fd2e192d8a5fd71e0874b4
URL: https://github.com/llvm/llvm-project/commit/6db200736c51a61834fd2e192d8a5fd71e0874b4
DIFF: https://github.com/llvm/llvm-project/commit/6db200736c51a61834fd2e192d8a5fd71e0874b4.diff
LOG: [mlir][linalg][bufferize][NFC] Use same OpBuilder throughout bufferization
Also set the insertion point right before calling `bufferize` on each op, so individual op interface implementations no longer need to take an InsertionGuard or set the insertion point themselves.
Differential Revision: https://reviews.llvm.org/D114928
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index ca7711dd032f4..1b3b3ff2f12e1 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -297,7 +297,8 @@ struct DialectBufferizationState {
/// the results of the analysis.
struct BufferizationState {
BufferizationState(ModuleOp moduleOp, const BufferizationOptions &options)
- : aliasInfo(moduleOp), options(options) {}
+ : aliasInfo(moduleOp), options(options),
+ builder(moduleOp->getContext()) {}
// BufferizationState should be passed as a reference.
BufferizationState(const BufferizationState &) = delete;
@@ -321,6 +322,11 @@ struct BufferizationState {
/// Return `true` if the given value is mapped.
bool isMapped(Value value) const;
+ /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
+ /// a new buffer and copy over data from the existing buffer if out-of-place
+ /// bufferization is necessary.
+ Value getResultBuffer(OpResult result);
+
/// Mark `op` as obsolete, so that it is deleted after bufferization.
void markOpObsolete(Operation *op);
@@ -349,12 +355,10 @@ struct BufferizationState {
/// A reference to current bufferization options.
const BufferizationOptions &options;
-};
-/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
-/// a new buffer and copy over data from the existing buffer if out-of-place
-/// bufferization is necessary.
-Value getResultBuffer(OpBuilder &b, OpResult result, BufferizationState &state);
+ /// The OpBuilder used during bufferization.
+ OpBuilder builder;
+};
/// Bufferize all ops in the given region.
LogicalResult bufferize(Region *region, BufferizationState &state);
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
index 8fb8f919c5ea6..fa0c96275daf0 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
@@ -26,21 +26,14 @@ struct ConstantOpInterface
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto constantOp = cast<arith::ConstantOp>(op);
- if (!constantOp.getResult().getType().isa<TensorType>())
- return success();
assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
"not a constant ranked tensor");
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
- if (!moduleOp) {
+ if (!moduleOp)
return constantOp.emitError(
"cannot bufferize constants not within builtin.module op");
- }
- GlobalCreator globalCreator(moduleOp);
-
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(constantOp);
+ GlobalCreator globalCreator(moduleOp);
auto globalMemref = globalCreator.getGlobalFor(constantOp);
Value memref = b.create<memref::GetGlobalOp>(
constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 4ed9ec9fee10a..4348fe4d5ad28 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -372,15 +372,15 @@ Value mlir::linalg::comprehensive_bufferize::findLastPrecedingWrite(
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
-Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
- OpBuilder &b, OpResult result, BufferizationState &state) {
- OpBuilder::InsertionGuard guard(b);
+Value mlir::linalg::comprehensive_bufferize::BufferizationState::
+ getResultBuffer(OpResult result) {
+ OpBuilder::InsertionGuard guard(builder);
Operation *op = result.getOwner();
SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
assert(!aliasingOperands.empty() && "could not get aliasing OpOperand");
OpOperand *opOperand = aliasingOperands.front();
Value operand = opOperand->get();
- Value operandBuffer = state.lookupBuffer(operand);
+ Value operandBuffer = lookupBuffer(operand);
// Make sure that all OpOperands are the same buffer. If this is not the case,
// we would have to materialize a memref value.
// TODO: Should be looking for checking for "equivalent buffers" instead of
@@ -388,14 +388,14 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
// set up yet.
if (aliasingOperands.size() > 1 &&
!llvm::all_of(aliasingOperands, [&](OpOperand *o) {
- return state.lookupBuffer(o->get()) == operandBuffer;
+ return lookupBuffer(o->get()) == operandBuffer;
})) {
op->emitError("result buffer is ambiguous");
return Value();
}
// If bufferizing out-of-place, allocate a new buffer.
- if (!state.aliasInfo.isInPlace(result)) {
+ if (!aliasInfo.isInPlace(result)) {
// Ops with multiple aliasing operands can currently not bufferize
// out-of-place.
assert(
@@ -404,9 +404,9 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
Location loc = op->getLoc();
// Move insertion point right after `operandBuffer`. That is where the
// allocation should be inserted (in the absence of allocation hoisting).
- setInsertionPointAfter(b, operandBuffer);
+ setInsertionPointAfter(builder, operandBuffer);
// Allocate the result buffer.
- Value resultBuffer = state.createAllocDeallocFn(b, loc, operandBuffer);
+ Value resultBuffer = createAllocDeallocFn(builder, loc, operandBuffer);
bool skipCopy = false;
// Do not copy if the last preceding write of `operand` is an op that does
// not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@@ -427,9 +427,9 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
skipCopy = true;
if (!skipCopy) {
// The copy happens right before the op that is bufferized.
- b.setInsertionPoint(op);
- state.options.allocationFns->memCpyFn(b, loc, operandBuffer,
- resultBuffer);
+ builder.setInsertionPoint(op);
+ options.allocationFns->memCpyFn(builder, loc, operandBuffer,
+ resultBuffer);
}
return resultBuffer;
}
@@ -459,7 +459,7 @@ mlir::linalg::comprehensive_bufferize::bufferize(Block *block,
LogicalResult
mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
BufferizationState &state) {
- OpBuilder b(op->getContext());
+ OpBuilder &b = state.builder;
// Check if op has tensor results or operands.
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 5f62c97ff2651..3f14975386a51 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -38,7 +38,7 @@ allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
OpResult opResult = cast<BufferizableOpInterface>(op.getOperation())
.getAliasingOpResult(*opOperand);
assert(opResult && "could not find correspond OpResult");
- Value resultBuffer = getResultBuffer(b, opResult, state);
+ Value resultBuffer = state.getResultBuffer(opResult);
if (!resultBuffer)
return failure();
resultBuffers.push_back(resultBuffer);
@@ -158,10 +158,6 @@ struct InitTensorOpInterface
if (initTensorOp->getUses().empty())
return success();
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(initTensorOp);
-
Value alloc = state.createAllocDeallocFn(b, initTensorOp->getLoc(),
initTensorOp.result());
state.mapBuffer(initTensorOp.result(), alloc);
@@ -250,7 +246,7 @@ struct TiledLoopOpInterface
const OpResult &opResult = tiledLoopOp->getResult(resultIndex);
OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex);
- Value resultBuffer = getResultBuffer(b, opResult, state);
+ Value resultBuffer = state.getResultBuffer(opResult);
if (!resultBuffer)
return failure();
@@ -350,11 +346,6 @@ struct YieldOpInterface
BufferizationState &state) const {
auto yieldOp = cast<linalg::YieldOp>(op);
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- // Cannot create IR past a yieldOp.
- b.setInsertionPoint(yieldOp);
-
// No tensors -> success.
if (!llvm::any_of(yieldOp.getOperandTypes(),
[](Type t) { return t.isa<TensorType>(); }))
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 94b14bc5622e0..30f51d5d2ca38 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -469,10 +469,6 @@ struct CallOpInterface
"expected Callop to a FuncOp");
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(callOp);
-
// 1. Filter return types:
// - if the callee is bodiless / external, we cannot inspect it and we
// cannot assume anything. We can just assert that it does not return a
@@ -600,14 +596,9 @@ struct ReturnOpInterface
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto returnOp = cast<ReturnOp>(op);
-
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- // Cannot insert after returnOp.
- b.setInsertionPoint(returnOp);
-
assert(isa<FuncOp>(returnOp->getParentOp()) &&
"only support FuncOp parent for ReturnOp");
+
for (OpOperand &operand : returnOp->getOpOperands()) {
auto tensorType = operand.get().getType().dyn_cast<TensorType>();
if (!tensorType)
@@ -628,9 +619,6 @@ struct FuncOpInterface
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto funcOp = cast<FuncOp>(op);
-
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
b.setInsertionPointToStart(&funcOp.body().front());
// Create BufferCastOps for function args.
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index 8632c65ca7c66..5dc434335e1ef 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -138,7 +138,7 @@ struct IfOpInterface
assert(opResult.getType().isa<RankedTensorType>() &&
"unsupported unranked tensor");
- Value resultBuffer = getResultBuffer(b, opResult, state);
+ Value resultBuffer = state.getResultBuffer(opResult);
if (!resultBuffer)
return failure();
@@ -204,7 +204,7 @@ struct ForOpInterface
"unsupported unranked tensor");
// TODO: More general: Matching bbArg does not bufferize to a read.
- Value resultBuffer = getResultBuffer(b, opResult, state);
+ Value resultBuffer = state.getResultBuffer(opResult);
if (!resultBuffer)
return failure();
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index d944d81267dc8..11333807dd7a8 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -60,11 +60,7 @@ struct CastOpInterface
BufferizationState &state) const {
auto castOp = cast<tensor::CastOp>(op);
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(castOp);
-
- Value resultBuffer = getResultBuffer(b, castOp->getResult(0), state);
+ Value resultBuffer = state.getResultBuffer(castOp->getResult(0));
if (!resultBuffer)
return failure();
Type sourceType = resultBuffer.getType();
@@ -107,11 +103,6 @@ struct DimOpInterface
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto dimOp = cast<tensor::DimOp>(op);
-
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(dimOp);
-
if (dimOp.source().getType().isa<RankedTensorType>()) {
Value v = state.lookupBuffer(dimOp.source());
dimOp.result().replaceAllUsesWith(
@@ -145,11 +136,6 @@ struct ExtractSliceOpInterface
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
-
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(extractSliceOp);
-
Location loc = extractSliceOp.getLoc();
Value srcMemref = state.lookupBuffer(extractSliceOp.source());
auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
@@ -207,11 +193,6 @@ struct ExtractOpInterface
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto extractOp = cast<tensor::ExtractOp>(op);
-
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(extractOp);
-
Location loc = extractOp.getLoc();
Value srcMemref = state.lookupBuffer(extractOp.tensor());
Value l = b.create<memref::LoadOp>(loc, srcMemref, extractOp.indices());
@@ -245,13 +226,8 @@ struct InsertOpInterface
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto insertOp = cast<tensor::InsertOp>(op);
-
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(insertOp);
-
Location loc = insertOp.getLoc();
- Value destMemref = getResultBuffer(b, insertOp->getOpResult(0), state);
+ Value destMemref = state.getResultBuffer(insertOp->getOpResult(0));
b.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
insertOp.indices());
state.mapBuffer(insertOp, destMemref);
@@ -419,15 +395,11 @@ struct InsertSliceOpInterface
// catastrophically bad scheduling decision.
// TODO: be very loud about it or even consider failing the pass.
auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
- TensorBufferizationState &tensorState = getTensorBufferizationState(state);
-
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(insertSliceOp);
Location loc = insertSliceOp.getLoc();
+ TensorBufferizationState &tensorState = getTensorBufferizationState(state);
// When bufferizing out-of-place, `getResultBuffer` allocates.
- Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state);
+ Value dstMemref = state.getResultBuffer(insertSliceOp->getResult(0));
if (!dstMemref)
return failure();
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index d582b88f1eb0a..3fafa75aa79bf 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -39,14 +39,10 @@ struct TransferReadOpInterface
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto transferReadOp = cast<vector::TransferReadOp>(op);
-
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(op);
-
- // TransferReadOp always reads from the bufferized op.source().
assert(transferReadOp.getShapedType().isa<TensorType>() &&
"only tensor types expected");
+
+ // TransferReadOp always reads from the bufferized op.source().
Value v = state.lookupBuffer(transferReadOp.source());
transferReadOp.sourceMutable().assign(v);
return success();
@@ -81,17 +77,13 @@ struct TransferWriteOpInterface
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto writeOp = cast<vector::TransferWriteOp>(op);
-
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(op);
+ assert(writeOp.getShapedType().isa<TensorType>() &&
+ "only tensor types expected");
// Create a new transfer_write on buffer that doesn't have a return value.
// Leave the previous transfer_write to dead code as it still has uses at
// this point.
- assert(writeOp.getShapedType().isa<TensorType>() &&
- "only tensor types expected");
- Value resultBuffer = getResultBuffer(b, op->getResult(0), state);
+ Value resultBuffer = state.getResultBuffer(op->getResult(0));
if (!resultBuffer)
return failure();
b.create<vector::TransferWriteOp>(
More information about the Mlir-commits
mailing list