[Mlir-commits] [mlir] ed5e359 - [mlir][linalg][bufferize][NFC] Remove RewriterBase from BufferizationState
Matthias Springer
llvmlistbot at llvm.org
Wed Jan 5 07:05:05 PST 2022
Author: Matthias Springer
Date: 2022-01-06T00:04:43+09:00
New Revision: ed5e3590a3b821c7086226a6ca5679d0b48d7bec
URL: https://github.com/llvm/llvm-project/commit/ed5e3590a3b821c7086226a6ca5679d0b48d7bec
DIFF: https://github.com/llvm/llvm-project/commit/ed5e3590a3b821c7086226a6ca5679d0b48d7bec.diff
LOG: [mlir][linalg][bufferize][NFC] Remove RewriterBase from BufferizationState
This change simplifies BufferizationState. Having `rewriter` in BufferizationState could be confusing to users because a rewriter is also passed to each `bufferize` function and it is not obvious (by looking at the API) that these two rewriters are the same.
Differential Revision: https://reviews.llvm.org/D116444
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index cfafc6b33bb70..3ec15fb301988 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -297,8 +297,7 @@ struct DialectBufferizationState {
/// * `replaceOp` replaces an op with new values.
class BufferizationState {
public:
- BufferizationState(Operation *op, const BufferizationOptions &options,
- RewriterBase &rewriter);
+ BufferizationState(Operation *op, const BufferizationOptions &options);
// BufferizationState should be passed as a reference.
BufferizationState(const BufferizationState &) = delete;
@@ -384,7 +383,7 @@ class BufferizationState {
/// Replace an op with replacement values. The op is deleted. Tensor OpResults
/// must be replaced with memref values.
- void replaceOp(Operation *op, ValueRange values);
+ void replaceOp(RewriterBase &rewriter, Operation *op, ValueRange values);
/// Replace an op with a new op. Tensor OpResults must be replaced with memref
/// values.
@@ -393,13 +392,13 @@ class BufferizationState {
Args &&...args) {
Operation *newOp =
rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
- replaceOp(op, newOp->getResults());
+ replaceOp(rewriter, op, newOp->getResults());
return cast<OpTy>(newOp);
}
/// Lookup the memref buffer that is associated to the given tensor value.
/// Asserts if no buffer is associated.
- Value lookupBuffer(Value tensor);
+ Value lookupBuffer(RewriterBase &rewriter, Value tensor);
/// Return `true` if the given OpResult has been decided to bufferize inplace.
bool isInPlace(OpResult opResult) const;
@@ -407,7 +406,7 @@ class BufferizationState {
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
- Value getResultBuffer(OpResult result);
+ Value getResultBuffer(RewriterBase &rewriter, OpResult result);
/// Return dialect-specific bufferization state.
template <typename StateT> StateT &getDialectState(StringRef name) {
@@ -420,9 +419,6 @@ class BufferizationState {
/// Return a reference to the BufferizationOptions.
const BufferizationOptions &getOptions() const { return options; }
- /// Return a reference to the rewriter.
- RewriterBase &getRewriter() { return rewriter; }
-
private:
friend LogicalResult
runComprehensiveBufferize(Operation *op, const BufferizationOptions &options,
@@ -441,21 +437,21 @@ class BufferizationState {
/// A reference to current bufferization options.
const BufferizationOptions &options;
-
- /// The OpBuilder used during bufferization.
- RewriterBase &rewriter;
};
/// Bufferize all ops in the given region.
-LogicalResult bufferize(Region *region, BufferizationState &state);
+LogicalResult bufferize(RewriterBase &rewriter, Region *region,
+ BufferizationState &state);
/// Bufferize all ops in the given block.
-LogicalResult bufferize(Block *block, BufferizationState &state);
+LogicalResult bufferize(RewriterBase &rewriter, Block *block,
+ BufferizationState &state);
/// Bufferize the given op. If the op has no tensor OpOperands/OpResults, this
/// function returns immediately. Otherwise, it calls the `bufferize` interface
/// method of `BufferizableOpInterface`.
-LogicalResult bufferize(Operation *op, BufferizationState &state);
+LogicalResult bufferize(RewriterBase &rewriter, Operation *op,
+ BufferizationState &state);
/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
/// with the same shape as `shapedType` and specified `layout` and
@@ -535,7 +531,7 @@ struct AllocationHoistingBarrierOnly
return op->emitError() << "unsupported op with tensors";
for (Region &region : op->getRegions())
- if (failed(comprehensive_bufferize::bufferize(&region, state)))
+ if (failed(comprehensive_bufferize::bufferize(rewriter, &region, state)))
return failure();
return success();
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index a639711196b46..816fafb0691ca 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -333,8 +333,8 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
}
mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
- Operation *op, const BufferizationOptions &options, RewriterBase &rewriter)
- : aliasInfo(op), options(options), rewriter(rewriter) {
+ Operation *op, const BufferizationOptions &options)
+ : aliasInfo(op), options(options) {
// Set up alias sets for OpResults that must bufferize in-place. This should
// be done before making any other bufferization decisions.
op->walk([&](BufferizableOpInterface bufferizableOp) {
@@ -360,14 +360,14 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
Value mlir::linalg::comprehensive_bufferize::BufferizationState::
- getResultBuffer(OpResult result) {
+ getResultBuffer(RewriterBase &rewriter, OpResult result) {
OpBuilder::InsertionGuard guard(rewriter);
Operation *op = result.getOwner();
SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
assert(!aliasingOperands.empty() && "could not get aliasing OpOperand");
OpOperand *opOperand = aliasingOperands.front();
Value operand = opOperand->get();
- Value operandBuffer = lookupBuffer(operand);
+ Value operandBuffer = lookupBuffer(rewriter, operand);
// Make sure that all OpOperands are the same buffer. If this is not the case,
// we would have to materialize a memref value.
// TODO: Should be looking for checking for "equivalent buffers" instead of
@@ -375,7 +375,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
// set up yet.
if (aliasingOperands.size() > 1 &&
!llvm::all_of(aliasingOperands, [&](OpOperand *o) {
- return lookupBuffer(o->get()) == operandBuffer;
+ return lookupBuffer(rewriter, o->get()) == operandBuffer;
})) {
op->emitError("result buffer is ambiguous");
return Value();
@@ -424,7 +424,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
}
void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp(
- Operation *op, ValueRange values) {
+ RewriterBase &rewriter, Operation *op, ValueRange values) {
OpBuilder::InsertionGuard g(rewriter);
// Replace all OpResults with the given values.
@@ -453,18 +453,16 @@ void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp(
rewriter.eraseOp(op);
}
-LogicalResult
-mlir::linalg::comprehensive_bufferize::bufferize(Region *region,
- BufferizationState &state) {
+LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
+ RewriterBase &rewriter, Region *region, BufferizationState &state) {
for (Block &block : *region)
- if (failed(bufferize(&block, state)))
+ if (failed(bufferize(rewriter, &block, state)))
return failure();
return success();
}
-LogicalResult
-mlir::linalg::comprehensive_bufferize::bufferize(Block *block,
- BufferizationState &state) {
+LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
+ RewriterBase &rewriter, Block *block, BufferizationState &state) {
// Ops may get deleted during the traversal, so do not iterate over `block`
// directly.
SmallVector<Operation *> ops;
@@ -472,16 +470,13 @@ mlir::linalg::comprehensive_bufferize::bufferize(Block *block,
for (Operation &op : *block)
ops.push_back(&op);
for (Operation *op : ops)
- if (failed(bufferize(op, state)))
+ if (failed(bufferize(rewriter, op, state)))
return failure();
return success();
}
-LogicalResult
-mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
- BufferizationState &state) {
- RewriterBase &rewriter = state.getRewriter();
-
+LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
+ RewriterBase &rewriter, Operation *op, BufferizationState &state) {
// Check if op has tensor results or operands.
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
@@ -505,7 +500,7 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
// Bufferize all regions.
for (Region &region : op->getRegions())
- if (failed(bufferize(&region, state)))
+ if (failed(bufferize(rewriter, &region, state)))
return failure();
return success();
@@ -654,7 +649,7 @@ bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) {
}
Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
- Value tensor) {
+ RewriterBase &rewriter, Value tensor) {
assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
// Replace "%t = to_tensor %m" with %m.
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 66adbe7d1fc8c..1d7ebfa39988b 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -651,8 +651,7 @@ annotateOpsWithBufferizationMarkers(Operation *op,
LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
Operation *op, std::unique_ptr<BufferizationOptions> options) {
- IRRewriter rewriter(op->getContext());
- BufferizationState state(op, *options, rewriter);
+ BufferizationState state(op, *options);
return runComprehensiveBufferize(op, *options, state);
}
@@ -660,6 +659,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
Operation *op, const BufferizationOptions &options,
BufferizationState &state) {
+ IRRewriter rewriter(op->getContext());
DominanceInfo domInfo(op);
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
@@ -690,7 +690,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
}
// Bufferize the op and its nested ops.
- if (failed(bufferize(op, state)))
+ if (failed(bufferize(rewriter, op, state)))
return failure();
return success();
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 9977e46b68782..dd9f123117541 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -45,14 +45,14 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
newInputBuffers.push_back(opOperand->get());
continue;
}
- newInputBuffers.push_back(state.lookupBuffer(opOperand->get()));
+ newInputBuffers.push_back(state.lookupBuffer(rewriter, opOperand->get()));
}
SmallVector<Value> newOutputBuffers;
for (OpOperand *opOperand : op.getOutputOperands()) {
OpResult opResult = op.getTiedOpResult(opOperand);
assert(opResult && "could not find correspond OpResult");
- Value resultBuffer = state.getResultBuffer(opResult);
+ Value resultBuffer = state.getResultBuffer(rewriter, opResult);
if (!resultBuffer)
return failure();
newOutputBuffers.push_back(resultBuffer);
@@ -68,9 +68,10 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
// Replace the results of the old op with the new output buffers.
- state.replaceOp(op, newOutputBuffers);
+ state.replaceOp(rewriter, op, newOutputBuffers);
- return comprehensive_bufferize::bufferize(bufferizedOp.getBlock(), state);
+ return comprehensive_bufferize::bufferize(rewriter, bufferizedOp.getBlock(),
+ state);
}
/// Linalg OpResults usually bufferize inplace with their tied (output
@@ -202,7 +203,7 @@ struct InitTensorOpInterface
Value alloc = state.createAllocDeallocPair(rewriter, initTensorOp->getLoc(),
initTensorOp.result());
- state.replaceOp(op, alloc);
+ state.replaceOp(rewriter, op, alloc);
return success();
}
};
@@ -259,7 +260,7 @@ struct TiledLoopOpInterface
SmallVector<Value> newInputs, newOutputs, newResults;
for (Value value : tiledLoopOp.inputs()) {
if (value.getType().isa<TensorType>()) {
- newInputs.push_back(state.lookupBuffer(value));
+ newInputs.push_back(state.lookupBuffer(rewriter, value));
} else {
newInputs.push_back(value);
}
@@ -267,8 +268,8 @@ struct TiledLoopOpInterface
int nextResultNum = 0;
for (Value value : tiledLoopOp.outputs()) {
if (value.getType().isa<TensorType>()) {
- Value buffer =
- state.getResultBuffer(tiledLoopOp->getResult(nextResultNum++));
+ Value buffer = state.getResultBuffer(
+ rewriter, tiledLoopOp->getResult(nextResultNum++));
newOutputs.push_back(buffer);
newResults.push_back(buffer);
} else {
@@ -328,10 +329,11 @@ struct TiledLoopOpInterface
rewriter.eraseOp(oldTerminator);
// Replace results and delete old op.
- state.replaceOp(op, newResults);
+ state.replaceOp(rewriter, op, newResults);
// Bufferize loop body.
- return comprehensive_bufferize::bufferize(newTiledLoopOp.getBody(), state);
+ return comprehensive_bufferize::bufferize(rewriter,
+ newTiledLoopOp.getBody(), state);
}
};
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index d622245718d65..ee4ced0b17396 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -219,6 +219,7 @@ static void equivalenceAnalysis(FuncOp funcOp,
/// originate from an op with an Alloc effect, they could be hoisted in the
/// future.
static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
+ RewriterBase &rewriter,
BufferizationState &state) {
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
@@ -277,7 +278,8 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
continue;
// Cast values at the call site if necessary.
- returnValues.push_back(getNonCastedValue(state.lookupBuffer(returnVal)));
+ returnValues.push_back(
+ getNonCastedValue(state.lookupBuffer(rewriter, returnVal)));
}
// 2. Rewrite the terminator without the inPlace bufferizable values.
@@ -510,7 +512,7 @@ struct CallOpInterface
/// In a first approximation, all the function arguments of a FuncOp are
/// marked inplaceable. For now, it is the responsibility of the `callOp`
/// bufferization to allow FuncOp that are inplaceable to write inPlace.
- LogicalResult bufferize(Operation *op, OpBuilder &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
CallOp callOp = cast<CallOp>(op);
FuncOp funcOp = getCalledFunction(callOp);
@@ -552,13 +554,13 @@ struct CallOpInterface
moduleState
.equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
- Value buffer = state.lookupBuffer(callOp->getOperand(idx));
+ Value buffer = state.lookupBuffer(rewriter, callOp->getOperand(idx));
// Add a ToTensorOp to kill all uses of the CallOp return.
// Replace all uses of the CallOp results so we can erase the CallOp.
// This ToTensorOp must fold/DCE away or bufferization should be
// considered failed.
- Value toTensorOp =
- b.create<bufferization::ToTensorOp>(callOp.getLoc(), buffer);
+ Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
+ callOp.getLoc(), buffer);
oldRes.replaceAllUsesWith(toTensorOp);
continue;
}
@@ -588,7 +590,7 @@ struct CallOpInterface
// Tensor operands are guaranteed to have been buferized.
int64_t idx = opOperand.getOperandNumber();
- Value buffer = state.lookupBuffer(tensorOperand);
+ Value buffer = state.lookupBuffer(rewriter, tensorOperand);
// Caller / callee type mistmatch is handled with a CastOp.
auto memRefType = bufferizedFuncType.getInput(idx);
@@ -598,16 +600,16 @@ struct CallOpInterface
// that will either canonicalize away or fail compilation until we can do
// something better.
if (buffer.getType() != memRefType) {
- Value castBuffer =
- b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
+ Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
+ memRefType, buffer);
buffer = castBuffer;
}
newOperands.push_back(buffer);
}
// 4. Create the new CallOp.
- Operation *newCallOp = b.create<CallOp>(callOp.getLoc(), funcOp.sym_name(),
- resultTypes, newOperands);
+ Operation *newCallOp = rewriter.create<CallOp>(
+ callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands);
newCallOp->setAttrs(callOp->getAttrs());
// 5. Delete the op at the end of bufferization.
@@ -635,7 +637,7 @@ struct ReturnOpInterface
return OpResult();
}
- LogicalResult bufferize(Operation *op, OpBuilder &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto returnOp = cast<ReturnOp>(op);
assert(isa<FuncOp>(returnOp->getParentOp()) &&
@@ -645,9 +647,9 @@ struct ReturnOpInterface
auto tensorType = operand.get().getType().dyn_cast<TensorType>();
if (!tensorType)
continue;
- Value v = state.lookupBuffer(operand.get());
- Value returnTensor = b.create<bufferization::ToTensorOp>(
- returnOp.getLoc(), v);
+ Value v = state.lookupBuffer(rewriter, operand.get());
+ Value returnTensor =
+ rewriter.create<bufferization::ToTensorOp>(returnOp.getLoc(), v);
operand.set(returnTensor);
}
return success();
@@ -656,12 +658,12 @@ struct ReturnOpInterface
struct FuncOpInterface
: public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
- LogicalResult bufferize(Operation *op, OpBuilder &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto funcOp = cast<FuncOp>(op);
// Bufferize function body.
- return comprehensive_bufferize::bufferize(&funcOp.body(), state);
+ return comprehensive_bufferize::bufferize(rewriter, &funcOp.body(), state);
}
/// Return `true` if the given function argument is writable.
@@ -726,7 +728,7 @@ static void annotateOpsWithBufferizationMarkers(FuncOp funcOp,
LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
ModuleOp moduleOp, std::unique_ptr<BufferizationOptions> options) {
IRRewriter rewriter(moduleOp.getContext());
- BufferizationState state(moduleOp, *options, rewriter);
+ BufferizationState state(moduleOp, *options);
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
@@ -766,7 +768,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
for (FuncOp funcOp : moduleState.orderedFuncOps) {
// Note: It would be good to apply cleanups here but we cannot as aliasInfo
// would be invalidated.
- if (failed(bufferizeFuncOpBoundary(funcOp, state)))
+ if (failed(bufferizeFuncOpBoundary(funcOp, rewriter, state)))
return failure();
if (!options->allowReturnMemref &&
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index 4b5eb1848ff72..d008607ed4c1a 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -70,8 +70,8 @@ struct ExecuteRegionOpInterface
if (hasTensorReturnType)
return op->emitError(
"scf.execute_region with tensor result not supported");
- return comprehensive_bufferize::bufferize(&executeRegionOp.getRegion(),
- state);
+ return comprehensive_bufferize::bufferize(
+ rewriter, &executeRegionOp.getRegion(), state);
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
@@ -194,12 +194,14 @@ struct IfOpInterface
}
// Replace op results.
- state.replaceOp(op, newIfOp->getResults());
+ state.replaceOp(rewriter, op, newIfOp->getResults());
// Bufferize then/else blocks.
- if (failed(comprehensive_bufferize::bufferize(newIfOp.thenBlock(), state)))
+ if (failed(comprehensive_bufferize::bufferize(rewriter, newIfOp.thenBlock(),
+ state)))
return failure();
- if (failed(comprehensive_bufferize::bufferize(newIfOp.elseBlock(), state)))
+ if (failed(comprehensive_bufferize::bufferize(rewriter, newIfOp.elseBlock(),
+ state)))
return failure();
return success();
@@ -299,7 +301,7 @@ struct ForOpInterface
// Construct a new scf.for op with memref instead of tensor values.
SmallVector<Value> initArgs =
convert(forOp.getInitArgs(), [&](Value val, int64_t index) {
- return state.getResultBuffer(forOp->getOpResult(index));
+ return state.getResultBuffer(rewriter, forOp->getOpResult(index));
});
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
@@ -333,10 +335,10 @@ struct ForOpInterface
yieldOp.getResultsMutable().assign(yieldValues);
// Replace loop results.
- state.replaceOp(op, newForOp->getResults());
+ state.replaceOp(rewriter, op, newForOp->getResults());
// Bufferize loop body.
- if (failed(comprehensive_bufferize::bufferize(loopBody, state)))
+ if (failed(comprehensive_bufferize::bufferize(rewriter, loopBody, state)))
return failure();
return success();
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index c837986cdeb29..0f91e52a5227e 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -65,7 +65,7 @@ struct CastOpInterface
BufferizationState &state) const {
auto castOp = cast<tensor::CastOp>(op);
- Value resultBuffer = state.getResultBuffer(castOp->getResult(0));
+ Value resultBuffer = state.getResultBuffer(rewriter, castOp->getResult(0));
if (!resultBuffer)
return failure();
Type sourceType = resultBuffer.getType();
@@ -111,7 +111,7 @@ struct DimOpInterface
auto dimOp = cast<tensor::DimOp>(op);
if (!dimOp.source().getType().isa<RankedTensorType>())
return dimOp.emitError("unranked tensor not supported");
- Value v = state.lookupBuffer(dimOp.source());
+ Value v = state.lookupBuffer(rewriter, dimOp.source());
state.replaceOpWithNewOp<memref::DimOp>(rewriter, op, v, dimOp.index());
return success();
}
@@ -147,7 +147,7 @@ struct ExtractSliceOpInterface
BufferizationState &state) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
Location loc = extractSliceOp.getLoc();
- Value srcMemref = state.lookupBuffer(extractSliceOp.source());
+ Value srcMemref = state.lookupBuffer(rewriter, extractSliceOp.source());
auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
auto dstTensorType =
extractSliceOp.result().getType().cast<RankedTensorType>();
@@ -178,7 +178,7 @@ struct ExtractSliceOpInterface
subView = alloc;
}
- state.replaceOp(op, subView);
+ state.replaceOp(rewriter, op, subView);
return success();
}
};
@@ -204,7 +204,7 @@ struct ExtractOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto extractOp = cast<tensor::ExtractOp>(op);
- Value srcMemref = state.lookupBuffer(extractOp.tensor());
+ Value srcMemref = state.lookupBuffer(rewriter, extractOp.tensor());
state.replaceOpWithNewOp<memref::LoadOp>(rewriter, op, srcMemref,
extractOp.indices());
return success();
@@ -241,10 +241,11 @@ struct InsertOpInterface
BufferizationState &state) const {
auto insertOp = cast<tensor::InsertOp>(op);
Location loc = insertOp.getLoc();
- Value destMemref = state.getResultBuffer(insertOp->getOpResult(0));
+ Value destMemref =
+ state.getResultBuffer(rewriter, insertOp->getOpResult(0));
rewriter.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
insertOp.indices());
- state.replaceOp(op, destMemref);
+ state.replaceOp(rewriter, op, destMemref);
return success();
}
@@ -421,7 +422,8 @@ struct InsertSliceOpInterface
TensorBufferizationState &tensorState = getTensorBufferizationState(state);
// When bufferizing out-of-place, `getResultBuffer` allocates.
- Value dstMemref = state.getResultBuffer(insertSliceOp->getResult(0));
+ Value dstMemref =
+ state.getResultBuffer(rewriter, insertSliceOp->getResult(0));
if (!dstMemref)
return failure();
@@ -440,11 +442,11 @@ struct InsertSliceOpInterface
loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
// Copy tensor.
- Value srcMemref = state.lookupBuffer(insertSliceOp.source());
+ Value srcMemref = state.lookupBuffer(rewriter, insertSliceOp.source());
state.createMemCpy(rewriter, insertSliceOp.getLoc(), srcMemref, subView);
}
- state.replaceOp(op, dstMemref);
+ state.replaceOp(rewriter, op, dstMemref);
return success();
}
};
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index 73d89bc549fd8..c8d335e66bc8c 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -46,12 +46,12 @@ struct TransferReadOpInterface
"only tensor types expected");
// TransferReadOp always reads from the bufferized op.source().
- Value buffer = state.lookupBuffer(readOp.source());
+ Value buffer = state.lookupBuffer(rewriter, readOp.source());
Value read = rewriter.create<vector::TransferReadOp>(
readOp.getLoc(), readOp.getVectorType(), buffer, readOp.indices(),
readOp.permutation_map(), readOp.padding(), readOp.mask(),
readOp.in_boundsAttr());
- state.replaceOp(op, read);
+ state.replaceOp(rewriter, op, read);
return success();
}
};
@@ -95,13 +95,13 @@ struct TransferWriteOpInterface
// Create a new transfer_write on buffer that doesn't have a return value.
// Leave the previous transfer_write to dead code as it still has uses at
// this point.
- Value resultBuffer = state.getResultBuffer(op->getResult(0));
+ Value resultBuffer = state.getResultBuffer(rewriter, op->getResult(0));
if (!resultBuffer)
return failure();
rewriter.create<vector::TransferWriteOp>(
writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
- state.replaceOp(op, resultBuffer);
+ state.replaceOp(rewriter, op, resultBuffer);
return success();
}
More information about the Mlir-commits
mailing list