[Mlir-commits] [mlir] 6c6bba7 - [mlir][linalg][bufferize][NFC] Use RewriterBase instead of OpBuilder
Matthias Springer
llvmlistbot at llvm.org
Wed Jan 5 04:05:57 PST 2022
Author: Matthias Springer
Date: 2022-01-05T21:05:42+09:00
New Revision: 6c6bba743674c4f72dfd1adb89d44475a9b3cb88
URL: https://github.com/llvm/llvm-project/commit/6c6bba743674c4f72dfd1adb89d44475a9b3cb88
DIFF: https://github.com/llvm/llvm-project/commit/6c6bba743674c4f72dfd1adb89d44475a9b3cb88.diff
LOG: [mlir][linalg][bufferize][NFC] Use RewriterBase instead of OpBuilder
This is in preparation for unifying core bufferization and Comprehensive Bufferize.
Differential Revision: https://reviews.llvm.org/D116102
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index bda6c25b2877..cfafc6b33bb7 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -14,6 +14,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/SetVector.h"
@@ -296,7 +297,8 @@ struct DialectBufferizationState {
/// * `replaceOp` replaces an op with new values.
class BufferizationState {
public:
- BufferizationState(Operation *op, const BufferizationOptions &options);
+ BufferizationState(Operation *op, const BufferizationOptions &options,
+ RewriterBase &rewriter);
// BufferizationState should be passed as a reference.
BufferizationState(const BufferizationState &) = delete;
@@ -387,9 +389,10 @@ class BufferizationState {
/// Replace an op with a new op. Tensor OpResults must be replaced with memref
/// values.
template <typename OpTy, typename... Args>
- OpTy replaceOpWithNewOp(OpBuilder &b, Operation *op, Args &&...args) {
+ OpTy replaceOpWithNewOp(RewriterBase &rewriter, Operation *op,
+ Args &&...args) {
Operation *newOp =
- b.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
+ rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
replaceOp(op, newOp->getResults());
return cast<OpTy>(newOp);
}
@@ -417,8 +420,8 @@ class BufferizationState {
/// Return a reference to the BufferizationOptions.
const BufferizationOptions &getOptions() const { return options; }
- /// Return a reference to the OpBuilder.
- OpBuilder &getBuilder() { return builder; }
+ /// Return a reference to the rewriter.
+ RewriterBase &getRewriter() { return rewriter; }
private:
friend LogicalResult
@@ -440,7 +443,7 @@ class BufferizationState {
const BufferizationOptions &options;
- /// The OpBuilder used during bufferization.
+ /// The rewriter used during bufferization.
- OpBuilder builder;
+ RewriterBase &rewriter;
};
/// Bufferize all ops in the given region.
@@ -523,7 +526,7 @@ struct AllocationHoistingBarrierOnly
return false;
}
- LogicalResult bufferize(Operation *op, OpBuilder &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
if (any_of(op->getOperandTypes(), isaTensor) ||
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
index df9090972bed..56c6b848c5f3 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
@@ -209,7 +209,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
}],
/*retType=*/"LogicalResult",
/*methodName=*/"bufferize",
- /*args=*/(ins "OpBuilder &":$b,
+ /*args=*/(ins "RewriterBase &":$rewriter,
"BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
index e370d3f43042..e8d0fa984bb0 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
@@ -23,7 +23,7 @@ namespace arith_ext {
struct ConstantOpInterface
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
arith::ConstantOp> {
- LogicalResult bufferize(Operation *op, OpBuilder &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto constantOp = cast<arith::ConstantOp>(op);
assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
@@ -35,8 +35,8 @@ struct ConstantOpInterface
GlobalCreator globalCreator(moduleOp);
auto globalMemref = globalCreator.getGlobalFor(constantOp);
- state.replaceOpWithNewOp<memref::GetGlobalOp>(b, op, globalMemref.type(),
- globalMemref.getName());
+ state.replaceOpWithNewOp<memref::GetGlobalOp>(
+ rewriter, op, globalMemref.type(), globalMemref.getName());
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index f7d22251eadb..a639711196b4 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -333,8 +333,8 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
}
mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
- Operation *op, const BufferizationOptions &options)
- : aliasInfo(op), options(options), builder(op->getContext()) {
+ Operation *op, const BufferizationOptions &options, RewriterBase &rewriter)
+ : aliasInfo(op), options(options), rewriter(rewriter) {
// Set up alias sets for OpResults that must bufferize in-place. This should
// be done before making any other bufferization decisions.
op->walk([&](BufferizableOpInterface bufferizableOp) {
@@ -361,7 +361,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
/// bufferization is necessary.
Value mlir::linalg::comprehensive_bufferize::BufferizationState::
getResultBuffer(OpResult result) {
- OpBuilder::InsertionGuard guard(builder);
+ OpBuilder::InsertionGuard guard(rewriter);
Operation *op = result.getOwner();
SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
assert(!aliasingOperands.empty() && "could not get aliasing OpOperand");
@@ -391,9 +391,9 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
Location loc = op->getLoc();
// Move insertion point right after `operandBuffer`. That is where the
// allocation should be inserted (in the absence of allocation hoisting).
- setInsertionPointAfter(builder, operandBuffer);
+ setInsertionPointAfter(rewriter, operandBuffer);
// Allocate the result buffer.
- Value resultBuffer = createAllocDeallocPair(builder, loc, operandBuffer);
+ Value resultBuffer = createAllocDeallocPair(rewriter, loc, operandBuffer);
bool skipCopy = false;
// Do not copy if the last preceding write of `operand` is an op that does
// not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@@ -413,8 +413,8 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
skipCopy = true;
if (!skipCopy) {
// The copy happens right before the op that is bufferized.
- builder.setInsertionPoint(op);
- createMemCpy(builder, loc, operandBuffer, resultBuffer);
+ rewriter.setInsertionPoint(op);
+ createMemCpy(rewriter, loc, operandBuffer, resultBuffer);
}
return resultBuffer;
}
@@ -425,8 +425,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp(
Operation *op, ValueRange values) {
- OpBuilder &b = getBuilder();
- OpBuilder::InsertionGuard g(b);
+ OpBuilder::InsertionGuard g(rewriter);
// Replace all OpResults with the given values.
for (OpResult opResult : op->getOpResults()) {
@@ -444,14 +443,14 @@ void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp(
// The existing uses of the OpResult still expect a tensor. Insert a
// ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
// loose all of its users and eventually DCE away.
- setInsertionPointAfter(b, replacement);
- replacement = b.create<bufferization::ToTensorOp>(replacement.getLoc(),
- replacement);
+ setInsertionPointAfter(rewriter, replacement);
+ replacement = rewriter.create<bufferization::ToTensorOp>(
+ replacement.getLoc(), replacement);
}
opResult.replaceAllUsesWith(replacement);
}
- op->erase();
+ rewriter.eraseOp(op);
}
LogicalResult
@@ -481,7 +480,7 @@ mlir::linalg::comprehensive_bufferize::bufferize(Block *block,
LogicalResult
mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
BufferizationState &state) {
- OpBuilder &b = state.getBuilder();
+ RewriterBase &rewriter = state.getRewriter();
// Check if op has tensor results or operands.
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
@@ -496,8 +495,8 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
// Bufferize using `BufferizableOpInterface`. Interface implementations are
// responsible for bufferizing nested ops.
if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) {
- b.setInsertionPoint(op);
- return bufferizableOp.bufferize(b, state);
+ rewriter.setInsertionPoint(op);
+ return bufferizableOp.bufferize(rewriter, state);
}
// `op` is an unbufferizable tensor op.
@@ -679,10 +678,9 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
}
// Insert to_memref op.
- OpBuilder &b = getBuilder();
- OpBuilder::InsertionGuard g(b);
- setInsertionPointAfter(b, tensor);
- return b.create<bufferization::ToMemrefOp>(
+ OpBuilder::InsertionGuard g(rewriter);
+ setInsertionPointAfter(rewriter, tensor);
+ return rewriter.create<bufferization::ToMemrefOp>(
tensor.getLoc(),
getDynamicMemRefType(tensor.getType().cast<RankedTensorType>()), tensor);
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
index 3419a6aa4492..eab925f02420 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
@@ -50,15 +50,14 @@ struct ToMemrefOpInterface
return OpResult();
}
- LogicalResult bufferize(Operation *op, OpBuilder &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto toMemrefOp = cast<bufferization::ToMemrefOp>(op);
// Fold to_memref(to_tensor(x)) to x.
if (auto toTensorOp =
toMemrefOp.tensor().getDefiningOp<bufferization::ToTensorOp>()) {
- toMemrefOp.replaceAllUsesWith(toTensorOp.memref());
- toMemrefOp.erase();
+ rewriter.replaceOp(toMemrefOp, toTensorOp.memref());
return success();
}
@@ -86,7 +85,7 @@ struct ToMemrefOpInterface
struct ToTensorOpInterface
: public BufferizableOpInterface::ExternalModel<ToTensorOpInterface,
bufferization::ToTensorOp> {
- LogicalResult bufferize(Operation *op, OpBuilder &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index c46746f6813f..66adbe7d1fc8 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -651,7 +651,8 @@ annotateOpsWithBufferizationMarkers(Operation *op,
LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
Operation *op, std::unique_ptr<BufferizationOptions> options) {
- BufferizationState state(op, *options);
+ IRRewriter rewriter(op->getContext());
+ BufferizationState state(op, *options, rewriter);
return runComprehensiveBufferize(op, *options, state);
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 190f0fea5108..9977e46b6878 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -23,11 +23,11 @@ namespace {
// TODO: Ops in the linalg dialect can directly implement this interface.
/// Generic conversion for any LinalgOp on tensors.
-static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
+static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
BufferizationState &state) {
// Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(op);
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
// Nothing to do. This op is already bufferized.
if (op.hasBufferSemantics())
@@ -63,9 +63,9 @@ static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
// Set insertion point now that potential alloc/dealloc are introduced.
- b.setInsertionPoint(op);
- auto bufferizedOp = cast<LinalgOp>(
- op.clone(b, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
+ rewriter.setInsertionPoint(op);
+ auto bufferizedOp = cast<LinalgOp>(op.clone(
+ rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
// Replace the results of the old op with the new output buffers.
state.replaceOp(op, newOutputBuffers);
@@ -177,9 +177,9 @@ struct LinalgOpInterface
return BufferRelation::Equivalent;
}
- LogicalResult bufferize(Operation *op, OpBuilder &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
- return bufferizeLinalgOp(b, cast<LinalgOp>(op), state);
+ return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), state);
}
};
@@ -192,7 +192,7 @@ struct InitTensorOpInterface
return false;
}
- LogicalResult bufferize(Operation *op, OpBuilder &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto initTensorOp = cast<linalg::InitTensorOp>(op);
@@ -200,7 +200,7 @@ struct InitTensorOpInterface
if (initTensorOp->getUses().empty())
return success();
- Value alloc = state.createAllocDeallocPair(b, initTensorOp->getLoc(),
+ Value alloc = state.createAllocDeallocPair(rewriter, initTensorOp->getLoc(),
initTensorOp.result());
state.replaceOp(op, alloc);
return success();
@@ -251,15 +251,10 @@ struct TiledLoopOpInterface
bool isAllocationHoistingBarrier(Operation *op) const { return true; }
- LogicalResult bufferize(Operation *op, OpBuilder &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
- // Use IRRewriter instead of OpBuilder because it has additional helper
- // functions.
- IRRewriter rewriter(op->getContext());
- rewriter.setInsertionPoint(tiledLoopOp);
-
// Compute new inputs, outputs and results.
SmallVector<Value> newInputs, newOutputs, newResults;
for (Value value : tiledLoopOp.inputs()) {
@@ -358,7 +353,7 @@ struct YieldOpInterface
return OpResult();
}
- LogicalResult bufferize(Operation *op, OpBuilder &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto yieldOp = cast<linalg::YieldOp>(op);
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index e7a5330ef399..d622245718d6 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -725,7 +725,8 @@ static void annotateOpsWithBufferizationMarkers(FuncOp funcOp,
LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
ModuleOp moduleOp, std::unique_ptr<BufferizationOptions> options) {
- BufferizationState state(moduleOp, *options);
+ IRRewriter rewriter(moduleOp.getContext());
+ BufferizationState state(moduleOp, *options, rewriter);
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index 5db5deb6aee6..4b5eb1848ff7 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -60,7 +60,7 @@ struct ExecuteRegionOpInterface
return true;
}
- LogicalResult bufferize(Operation *op, OpBuilder &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
// TODO: Add bufferization support when needed. scf.execute_region should be
// bufferized similar to scf.if.
@@ -135,15 +135,10 @@ struct IfOpInterface
return true;
}
- LogicalResult bufferize(Operation *op, OpBuilder &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto ifOp = cast<scf::IfOp>(op);
- // Use IRRewriter instead of OpBuilder because it has additional helper
- // functions.
- IRRewriter rewriter(op->getContext());
- rewriter.setInsertionPoint(ifOp);
-
// Compute new types of the bufferized scf.if op.
SmallVector<Type> newTypes;
for (Type returnType : ifOp->getResultTypes()) {
@@ -276,16 +271,11 @@ struct ForOpInterface
return true;
}
- LogicalResult bufferize(Operation *op, OpBuilder & /*b*/,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto forOp = cast<scf::ForOp>(op);
Block *oldLoopBody = &forOp.getLoopBody().front();
- // Use IRRewriter instead of OpBuilder because it has additional helper
- // functions.
- IRRewriter rewriter(op->getContext());
- rewriter.setInsertionPoint(forOp);
-
// Indices of all iter_args that have tensor type. These are the ones that
// are bufferized.
DenseSet<int64_t> indices;
@@ -438,7 +428,7 @@ struct YieldOpInterface
return OpResult();
}
- LogicalResult bufferize(Operation *op, OpBuilder &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto yieldOp = cast<scf::YieldOp>(op);
if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>(
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 30ca9ed0a78b..c837986cdeb2 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -61,7 +61,7 @@ struct CastOpInterface
return BufferRelation::Equivalent;
}
- LogicalResult bufferize(Operation *op, OpBuilder &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto castOp = cast<tensor::CastOp>(op);
@@ -82,7 +82,8 @@ struct CastOpInterface
: MemRefLayoutAttrInterface();
Type memRefType = getContiguousOrUnrankedMemRefType(
castOp.getResult().getType(), layout, memorySpace);
- state.replaceOpWithNewOp<memref::CastOp>(b, op, memRefType, resultBuffer);
+ state.replaceOpWithNewOp<memref::CastOp>(rewriter, op, memRefType,
+ resultBuffer);
return success();
}
};
@@ -105,13 +106,13 @@ struct DimOpInterface
return OpResult();
}
- LogicalResult bufferize(Operation *op, OpBuilder &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto dimOp = cast<tensor::DimOp>(op);
if (!dimOp.source().getType().isa<RankedTensorType>())
return dimOp.emitError("unranked tensor not supported");
Value v = state.lookupBuffer(dimOp.source());
- state.replaceOpWithNewOp<memref::DimOp>(b, op, v, dimOp.index());
+ state.replaceOpWithNewOp<memref::DimOp>(rewriter, op, v, dimOp.index());
return success();
}
};
@@ -142,7 +143,7 @@ struct ExtractSliceOpInterface
return BufferRelation::None;
}
- LogicalResult bufferize(Operation *op, OpBuilder &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
Location loc = extractSliceOp.getLoc();
@@ -155,7 +156,8 @@ struct ExtractSliceOpInterface
bool inplace = state.isInPlace(extractSliceOp->getResult(0));
Value alloc;
if (!inplace)
- alloc = state.createAllocDeallocPair(b, loc, extractSliceOp.result());
+ alloc =
+ state.createAllocDeallocPair(rewriter, loc, extractSliceOp.result());
// Bufferize to subview.
auto subviewMemRefType =
@@ -164,7 +166,7 @@ struct ExtractSliceOpInterface
extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
extractSliceOp.getMixedStrides())
.cast<MemRefType>();
- Value subView = b.create<memref::SubViewOp>(
+ Value subView = rewriter.create<memref::SubViewOp>(
loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
@@ -172,7 +174,7 @@ struct ExtractSliceOpInterface
if (!inplace) {
// Do not copy if the copied data is never read.
if (state.isValueRead(extractSliceOp.result()))
- state.createMemCpy(b, extractSliceOp.getLoc(), subView, alloc);
+ state.createMemCpy(rewriter, extractSliceOp.getLoc(), subView, alloc);
subView = alloc;
}
@@ -199,11 +201,11 @@ struct ExtractOpInterface
return OpResult();
}
- LogicalResult bufferize(Operation *op, OpBuilder &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto extractOp = cast<tensor::ExtractOp>(op);
Value srcMemref = state.lookupBuffer(extractOp.tensor());
- state.replaceOpWithNewOp<memref::LoadOp>(b, op, srcMemref,
+ state.replaceOpWithNewOp<memref::LoadOp>(rewriter, op, srcMemref,
extractOp.indices());
return success();
}
@@ -235,13 +237,13 @@ struct InsertOpInterface
return {&op->getOpOperand(1) /*dest*/};
}
- LogicalResult bufferize(Operation *op, OpBuilder &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto insertOp = cast<tensor::InsertOp>(op);
Location loc = insertOp.getLoc();
Value destMemref = state.getResultBuffer(insertOp->getOpResult(0));
- b.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
- insertOp.indices());
+ rewriter.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
+ insertOp.indices());
state.replaceOp(op, destMemref);
return success();
}
@@ -407,7 +409,7 @@ struct InsertSliceOpInterface
return false;
}
- LogicalResult bufferize(Operation *op, OpBuilder &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
// insert_slice ops arise from tiling and bufferizing them out-of-place is
// generally a deal breaker. When used with loops, this ends up cloning the
@@ -434,12 +436,12 @@ struct InsertSliceOpInterface
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
insertSliceOp.getMixedStrides())
.cast<MemRefType>();
- Value subView = b.create<memref::SubViewOp>(
+ Value subView = rewriter.create<memref::SubViewOp>(
loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
// Copy tensor.
Value srcMemref = state.lookupBuffer(insertSliceOp.source());
- state.createMemCpy(b, insertSliceOp.getLoc(), srcMemref, subView);
+ state.createMemCpy(rewriter, insertSliceOp.getLoc(), srcMemref, subView);
}
state.replaceOp(op, dstMemref);
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index 50ceb5aa77c9..73d89bc549fd 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -39,7 +39,7 @@ struct TransferReadOpInterface
return OpResult();
}
- LogicalResult bufferize(Operation *op, OpBuilder &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto readOp = cast<vector::TransferReadOp>(op);
assert(readOp.getShapedType().isa<TensorType>() &&
@@ -47,7 +47,7 @@ struct TransferReadOpInterface
// TransferReadOp always reads from the bufferized op.source().
Value buffer = state.lookupBuffer(readOp.source());
- Value read = b.create<vector::TransferReadOp>(
+ Value read = rewriter.create<vector::TransferReadOp>(
readOp.getLoc(), readOp.getVectorType(), buffer, readOp.indices(),
readOp.permutation_map(), readOp.padding(), readOp.mask(),
readOp.in_boundsAttr());
@@ -86,7 +86,7 @@ struct TransferWriteOpInterface
return BufferRelation::Equivalent;
}
- LogicalResult bufferize(Operation *op, OpBuilder &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto writeOp = cast<vector::TransferWriteOp>(op);
assert(writeOp.getShapedType().isa<TensorType>() &&
@@ -98,7 +98,7 @@ struct TransferWriteOpInterface
Value resultBuffer = state.getResultBuffer(op->getResult(0));
if (!resultBuffer)
return failure();
- b.create<vector::TransferWriteOp>(
+ rewriter.create<vector::TransferWriteOp>(
writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
state.replaceOp(op, resultBuffer);
More information about the Mlir-commits
mailing list