[Mlir-commits] [mlir] ed5e359 - [mlir][linalg][bufferize][NFC] Remove RewriterBase from BufferizationState

Matthias Springer <llvmlistbot at llvm.org>
Wed Jan 5 07:05:05 PST 2022


Author: Matthias Springer
Date: 2022-01-06T00:04:43+09:00
New Revision: ed5e3590a3b821c7086226a6ca5679d0b48d7bec

URL: https://github.com/llvm/llvm-project/commit/ed5e3590a3b821c7086226a6ca5679d0b48d7bec
DIFF: https://github.com/llvm/llvm-project/commit/ed5e3590a3b821c7086226a6ca5679d0b48d7bec.diff

LOG: [mlir][linalg][bufferize][NFC] Remove RewriterBase from BufferizationState

This change simplifies BufferizationState. Storing a `rewriter` in BufferizationState was potentially confusing to users: a rewriter is also passed to each `bufferize` function, and it is not obvious from the API alone that the two rewriters are the same.

Differential Revision: https://reviews.llvm.org/D116444
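
For readers tracking the API change, here is a minimal before/after sketch of a BufferizableOpInterface external model under the new signatures. This is illustration only, not code from the patch: `MyOp` and `MyOpInterface` are hypothetical names, and the body is simplified from the real implementations in the modified files below.

    // After this change, the rewriter is threaded explicitly through the
    // BufferizationState methods instead of being stored in the state.
    #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"

    using namespace mlir;
    using namespace mlir::linalg::comprehensive_bufferize;

    struct MyOpInterface
        : public BufferizableOpInterface::ExternalModel<MyOpInterface, MyOp> {
      LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                              BufferizationState &state) const {
        // Previously: state.lookupBuffer(tensor) and state.replaceOp(op, ...),
        // with the same rewriter hidden inside `state`.
        Value buffer = state.lookupBuffer(rewriter, op->getOperand(0));
        state.replaceOp(rewriter, op, buffer);
        return success();
      }
    };

Top-level drivers now construct the state without a rewriter (`BufferizationState state(op, *options);`) and create their own `IRRewriter`, as `runComprehensiveBufferize` does in the diff below.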

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index cfafc6b33bb70..3ec15fb301988 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -297,8 +297,7 @@ struct DialectBufferizationState {
 /// * `replaceOp` replaces an op with new values.
 class BufferizationState {
 public:
-  BufferizationState(Operation *op, const BufferizationOptions &options,
-                     RewriterBase &rewriter);
+  BufferizationState(Operation *op, const BufferizationOptions &options);
 
   // BufferizationState should be passed as a reference.
   BufferizationState(const BufferizationState &) = delete;
@@ -384,7 +383,7 @@ class BufferizationState {
 
   /// Replace an op with replacement values. The op is deleted. Tensor OpResults
   /// must be replaced with memref values.
-  void replaceOp(Operation *op, ValueRange values);
+  void replaceOp(RewriterBase &rewriter, Operation *op, ValueRange values);
 
   /// Replace an op with a new op. Tensor OpResults must be replaced with memref
   /// values.
@@ -393,13 +392,13 @@ class BufferizationState {
                           Args &&...args) {
     Operation *newOp =
         rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
-    replaceOp(op, newOp->getResults());
+    replaceOp(rewriter, op, newOp->getResults());
     return cast<OpTy>(newOp);
   }
 
   /// Lookup the memref buffer that is associated to the given tensor value.
   /// Asserts if no buffer is associated.
-  Value lookupBuffer(Value tensor);
+  Value lookupBuffer(RewriterBase &rewriter, Value tensor);
 
   /// Return `true` if the given OpResult has been decided to bufferize inplace.
   bool isInPlace(OpResult opResult) const;
@@ -407,7 +406,7 @@ class BufferizationState {
   /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
   /// a new buffer and copy over data from the existing buffer if out-of-place
   /// bufferization is necessary.
-  Value getResultBuffer(OpResult result);
+  Value getResultBuffer(RewriterBase &rewriter, OpResult result);
 
   /// Return dialect-specific bufferization state.
   template <typename StateT> StateT &getDialectState(StringRef name) {
@@ -420,9 +419,6 @@ class BufferizationState {
   /// Return a reference to the BufferizationOptions.
   const BufferizationOptions &getOptions() const { return options; }
 
-  /// Return a reference to the rewriter.
-  RewriterBase &getRewriter() { return rewriter; }
-
 private:
   friend LogicalResult
   runComprehensiveBufferize(Operation *op, const BufferizationOptions &options,
@@ -441,21 +437,21 @@ class BufferizationState {
 
   /// A reference to current bufferization options.
   const BufferizationOptions &options;
-
-  /// The OpBuilder used during bufferization.
-  RewriterBase &rewriter;
 };
 
 /// Bufferize all ops in the given region.
-LogicalResult bufferize(Region *region, BufferizationState &state);
+LogicalResult bufferize(RewriterBase &rewriter, Region *region,
+                        BufferizationState &state);
 
 /// Bufferize all ops in the given block.
-LogicalResult bufferize(Block *block, BufferizationState &state);
+LogicalResult bufferize(RewriterBase &rewriter, Block *block,
+                        BufferizationState &state);
 
 /// Bufferize the given op. If the op has no tensor OpOperands/OpResults, this
 /// function returns immediately. Otherwise, it calls the `bufferize` interface
 /// method of `BufferizableOpInterface`.
-LogicalResult bufferize(Operation *op, BufferizationState &state);
+LogicalResult bufferize(RewriterBase &rewriter, Operation *op,
+                        BufferizationState &state);
 
 /// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
 /// with the same shape as `shapedType` and specified `layout` and
@@ -535,7 +531,7 @@ struct AllocationHoistingBarrierOnly
         return op->emitError() << "unsupported op with tensors";
 
     for (Region &region : op->getRegions())
-      if (failed(comprehensive_bufferize::bufferize(&region, state)))
+      if (failed(comprehensive_bufferize::bufferize(rewriter, &region, state)))
         return failure();
 
     return success();

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index a639711196b46..816fafb0691ca 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -333,8 +333,8 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
 }
 
 mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
-    Operation *op, const BufferizationOptions &options, RewriterBase &rewriter)
-    : aliasInfo(op), options(options), rewriter(rewriter) {
+    Operation *op, const BufferizationOptions &options)
+    : aliasInfo(op), options(options) {
   // Set up alias sets for OpResults that must bufferize in-place. This should
   // be done before making any other bufferization decisions.
   op->walk([&](BufferizableOpInterface bufferizableOp) {
@@ -360,14 +360,14 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
 /// a new buffer and copy over data from the existing buffer if out-of-place
 /// bufferization is necessary.
 Value mlir::linalg::comprehensive_bufferize::BufferizationState::
-    getResultBuffer(OpResult result) {
+    getResultBuffer(RewriterBase &rewriter, OpResult result) {
   OpBuilder::InsertionGuard guard(rewriter);
   Operation *op = result.getOwner();
   SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
   assert(!aliasingOperands.empty() && "could not get aliasing OpOperand");
   OpOperand *opOperand = aliasingOperands.front();
   Value operand = opOperand->get();
-  Value operandBuffer = lookupBuffer(operand);
+  Value operandBuffer = lookupBuffer(rewriter, operand);
   // Make sure that all OpOperands are the same buffer. If this is not the case,
   // we would have to materialize a memref value.
   // TODO: Should be looking for checking for "equivalent buffers" instead of
@@ -375,7 +375,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
   // set up yet.
   if (aliasingOperands.size() > 1 &&
       !llvm::all_of(aliasingOperands, [&](OpOperand *o) {
-        return lookupBuffer(o->get()) == operandBuffer;
+        return lookupBuffer(rewriter, o->get()) == operandBuffer;
       })) {
     op->emitError("result buffer is ambiguous");
     return Value();
@@ -424,7 +424,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
 }
 
 void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp(
-    Operation *op, ValueRange values) {
+    RewriterBase &rewriter, Operation *op, ValueRange values) {
   OpBuilder::InsertionGuard g(rewriter);
 
   // Replace all OpResults with the given values.
@@ -453,18 +453,16 @@ void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp(
   rewriter.eraseOp(op);
 }
 
-LogicalResult
-mlir::linalg::comprehensive_bufferize::bufferize(Region *region,
-                                                 BufferizationState &state) {
+LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
+    RewriterBase &rewriter, Region *region, BufferizationState &state) {
   for (Block &block : *region)
-    if (failed(bufferize(&block, state)))
+    if (failed(bufferize(rewriter, &block, state)))
       return failure();
   return success();
 }
 
-LogicalResult
-mlir::linalg::comprehensive_bufferize::bufferize(Block *block,
-                                                 BufferizationState &state) {
+LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
+    RewriterBase &rewriter, Block *block, BufferizationState &state) {
   // Ops may get deleted during the traversal, so do not iterate over `block`
   // directly.
   SmallVector<Operation *> ops;
@@ -472,16 +470,13 @@ mlir::linalg::comprehensive_bufferize::bufferize(Block *block,
   for (Operation &op : *block)
     ops.push_back(&op);
   for (Operation *op : ops)
-    if (failed(bufferize(op, state)))
+    if (failed(bufferize(rewriter, op, state)))
       return failure();
   return success();
 }
 
-LogicalResult
-mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
-                                                 BufferizationState &state) {
-  RewriterBase &rewriter = state.getRewriter();
-
+LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
+    RewriterBase &rewriter, Operation *op, BufferizationState &state) {
   // Check if op has tensor results or operands.
   auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
   bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
@@ -505,7 +500,7 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
 
   // Bufferize all regions.
   for (Region &region : op->getRegions())
-    if (failed(bufferize(&region, state)))
+    if (failed(bufferize(rewriter, &region, state)))
       return failure();
 
   return success();
@@ -654,7 +649,7 @@ bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) {
 }
 
 Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
-    Value tensor) {
+    RewriterBase &rewriter, Value tensor) {
   assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
 
   // Replace "%t = to_tensor %m" with %m.

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 66adbe7d1fc8c..1d7ebfa39988b 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -651,8 +651,7 @@ annotateOpsWithBufferizationMarkers(Operation *op,
 
 LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     Operation *op, std::unique_ptr<BufferizationOptions> options) {
-  IRRewriter rewriter(op->getContext());
-  BufferizationState state(op, *options, rewriter);
+  BufferizationState state(op, *options);
   return runComprehensiveBufferize(op, *options, state);
 }
 
@@ -660,6 +659,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     Operation *op, const BufferizationOptions &options,
     BufferizationState &state) {
 
+  IRRewriter rewriter(op->getContext());
   DominanceInfo domInfo(op);
   BufferizationAliasInfo &aliasInfo = state.aliasInfo;
 
@@ -690,7 +690,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
   }
 
   // Bufferize the op and its nested ops.
-  if (failed(bufferize(op, state)))
+  if (failed(bufferize(rewriter, op, state)))
     return failure();
 
   return success();

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 9977e46b68782..dd9f123117541 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -45,14 +45,14 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
       newInputBuffers.push_back(opOperand->get());
       continue;
     }
-    newInputBuffers.push_back(state.lookupBuffer(opOperand->get()));
+    newInputBuffers.push_back(state.lookupBuffer(rewriter, opOperand->get()));
   }
 
   SmallVector<Value> newOutputBuffers;
   for (OpOperand *opOperand : op.getOutputOperands()) {
     OpResult opResult = op.getTiedOpResult(opOperand);
     assert(opResult && "could not find correspond OpResult");
-    Value resultBuffer = state.getResultBuffer(opResult);
+    Value resultBuffer = state.getResultBuffer(rewriter, opResult);
     if (!resultBuffer)
       return failure();
     newOutputBuffers.push_back(resultBuffer);
@@ -68,9 +68,10 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
       rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
 
   // Replace the results of the old op with the new output buffers.
-  state.replaceOp(op, newOutputBuffers);
+  state.replaceOp(rewriter, op, newOutputBuffers);
 
-  return comprehensive_bufferize::bufferize(bufferizedOp.getBlock(), state);
+  return comprehensive_bufferize::bufferize(rewriter, bufferizedOp.getBlock(),
+                                            state);
 }
 
 /// Linalg OpResults usually bufferize inplace with their tied (output
@@ -202,7 +203,7 @@ struct InitTensorOpInterface
 
     Value alloc = state.createAllocDeallocPair(rewriter, initTensorOp->getLoc(),
                                                initTensorOp.result());
-    state.replaceOp(op, alloc);
+    state.replaceOp(rewriter, op, alloc);
     return success();
   }
 };
@@ -259,7 +260,7 @@ struct TiledLoopOpInterface
     SmallVector<Value> newInputs, newOutputs, newResults;
     for (Value value : tiledLoopOp.inputs()) {
       if (value.getType().isa<TensorType>()) {
-        newInputs.push_back(state.lookupBuffer(value));
+        newInputs.push_back(state.lookupBuffer(rewriter, value));
       } else {
         newInputs.push_back(value);
       }
@@ -267,8 +268,8 @@ struct TiledLoopOpInterface
     int nextResultNum = 0;
     for (Value value : tiledLoopOp.outputs()) {
       if (value.getType().isa<TensorType>()) {
-        Value buffer =
-            state.getResultBuffer(tiledLoopOp->getResult(nextResultNum++));
+        Value buffer = state.getResultBuffer(
+            rewriter, tiledLoopOp->getResult(nextResultNum++));
         newOutputs.push_back(buffer);
         newResults.push_back(buffer);
       } else {
@@ -328,10 +329,11 @@ struct TiledLoopOpInterface
     rewriter.eraseOp(oldTerminator);
 
     // Replace results and delete old op.
-    state.replaceOp(op, newResults);
+    state.replaceOp(rewriter, op, newResults);
 
     // Bufferize loop body.
-    return comprehensive_bufferize::bufferize(newTiledLoopOp.getBody(), state);
+    return comprehensive_bufferize::bufferize(rewriter,
+                                              newTiledLoopOp.getBody(), state);
   }
 };
 

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index d622245718d65..ee4ced0b17396 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -219,6 +219,7 @@ static void equivalenceAnalysis(FuncOp funcOp,
 /// originate from an op with an Alloc effect, they could be hoisted in the
 /// future.
 static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
+                                             RewriterBase &rewriter,
                                              BufferizationState &state) {
   ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
 
@@ -277,7 +278,8 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
       continue;
 
     // Cast values at the call site if necessary.
-    returnValues.push_back(getNonCastedValue(state.lookupBuffer(returnVal)));
+    returnValues.push_back(
+        getNonCastedValue(state.lookupBuffer(rewriter, returnVal)));
   }
 
   // 2. Rewrite the terminator without the inPlace bufferizable values.
@@ -510,7 +512,7 @@ struct CallOpInterface
   /// In a first approximation, all the function arguments of a FuncOp are
   /// marked inplaceable. For now, it is the responsibility of the `callOp`
   /// bufferization to allow FuncOp that are inplaceable to write inPlace.
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     CallOp callOp = cast<CallOp>(op);
     FuncOp funcOp = getCalledFunction(callOp);
@@ -552,13 +554,13 @@ struct CallOpInterface
               moduleState
                   .equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
           Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
-          Value buffer = state.lookupBuffer(callOp->getOperand(idx));
+          Value buffer = state.lookupBuffer(rewriter, callOp->getOperand(idx));
           // Add a ToTensorOp to kill all uses of the CallOp return.
           // Replace all uses of the CallOp results so we can erase the CallOp.
           // This ToTensorOp must fold/DCE away or bufferization should be
           // considered failed.
-          Value toTensorOp =
-              b.create<bufferization::ToTensorOp>(callOp.getLoc(), buffer);
+          Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
+              callOp.getLoc(), buffer);
           oldRes.replaceAllUsesWith(toTensorOp);
           continue;
         }
@@ -588,7 +590,7 @@ struct CallOpInterface
 
       // Tensor operands are guaranteed to have been buferized.
       int64_t idx = opOperand.getOperandNumber();
-      Value buffer = state.lookupBuffer(tensorOperand);
+      Value buffer = state.lookupBuffer(rewriter, tensorOperand);
 
       // Caller / callee type mistmatch is handled with a CastOp.
       auto memRefType = bufferizedFuncType.getInput(idx);
@@ -598,16 +600,16 @@ struct CallOpInterface
       // that will either canonicalize away or fail compilation until we can do
       // something better.
       if (buffer.getType() != memRefType) {
-        Value castBuffer =
-            b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
+        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
+                                                           memRefType, buffer);
         buffer = castBuffer;
       }
       newOperands.push_back(buffer);
     }
 
     // 4. Create the new CallOp.
-    Operation *newCallOp = b.create<CallOp>(callOp.getLoc(), funcOp.sym_name(),
-                                            resultTypes, newOperands);
+    Operation *newCallOp = rewriter.create<CallOp>(
+        callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands);
     newCallOp->setAttrs(callOp->getAttrs());
 
     // 5. Delete the op at the end of bufferization.
@@ -635,7 +637,7 @@ struct ReturnOpInterface
     return OpResult();
   }
 
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto returnOp = cast<ReturnOp>(op);
     assert(isa<FuncOp>(returnOp->getParentOp()) &&
@@ -645,9 +647,9 @@ struct ReturnOpInterface
       auto tensorType = operand.get().getType().dyn_cast<TensorType>();
       if (!tensorType)
         continue;
-      Value v = state.lookupBuffer(operand.get());
-      Value returnTensor = b.create<bufferization::ToTensorOp>(
-          returnOp.getLoc(), v);
+      Value v = state.lookupBuffer(rewriter, operand.get());
+      Value returnTensor =
+          rewriter.create<bufferization::ToTensorOp>(returnOp.getLoc(), v);
       operand.set(returnTensor);
     }
     return success();
@@ -656,12 +658,12 @@ struct ReturnOpInterface
 
 struct FuncOpInterface
     : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto funcOp = cast<FuncOp>(op);
 
     // Bufferize function body.
-    return comprehensive_bufferize::bufferize(&funcOp.body(), state);
+    return comprehensive_bufferize::bufferize(rewriter, &funcOp.body(), state);
   }
 
   /// Return `true` if the given function argument is writable.
@@ -726,7 +728,7 @@ static void annotateOpsWithBufferizationMarkers(FuncOp funcOp,
 LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     ModuleOp moduleOp, std::unique_ptr<BufferizationOptions> options) {
   IRRewriter rewriter(moduleOp.getContext());
-  BufferizationState state(moduleOp, *options, rewriter);
+  BufferizationState state(moduleOp, *options);
   ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
   BufferizationAliasInfo &aliasInfo = state.aliasInfo;
 
@@ -766,7 +768,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
   for (FuncOp funcOp : moduleState.orderedFuncOps) {
     // Note: It would be good to apply cleanups here but we cannot as aliasInfo
     // would be invalidated.
-    if (failed(bufferizeFuncOpBoundary(funcOp, state)))
+    if (failed(bufferizeFuncOpBoundary(funcOp, rewriter, state)))
       return failure();
 
     if (!options->allowReturnMemref &&

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index 4b5eb1848ff72..d008607ed4c1a 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -70,8 +70,8 @@ struct ExecuteRegionOpInterface
     if (hasTensorReturnType)
       return op->emitError(
           "scf.execute_region with tensor result not supported");
-    return comprehensive_bufferize::bufferize(&executeRegionOp.getRegion(),
-                                              state);
+    return comprehensive_bufferize::bufferize(
+        rewriter, &executeRegionOp.getRegion(), state);
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
@@ -194,12 +194,14 @@ struct IfOpInterface
     }
 
     // Replace op results.
-    state.replaceOp(op, newIfOp->getResults());
+    state.replaceOp(rewriter, op, newIfOp->getResults());
 
     // Bufferize then/else blocks.
-    if (failed(comprehensive_bufferize::bufferize(newIfOp.thenBlock(), state)))
+    if (failed(comprehensive_bufferize::bufferize(rewriter, newIfOp.thenBlock(),
+                                                  state)))
       return failure();
-    if (failed(comprehensive_bufferize::bufferize(newIfOp.elseBlock(), state)))
+    if (failed(comprehensive_bufferize::bufferize(rewriter, newIfOp.elseBlock(),
+                                                  state)))
       return failure();
 
     return success();
@@ -299,7 +301,7 @@ struct ForOpInterface
     // Construct a new scf.for op with memref instead of tensor values.
     SmallVector<Value> initArgs =
         convert(forOp.getInitArgs(), [&](Value val, int64_t index) {
-          return state.getResultBuffer(forOp->getOpResult(index));
+          return state.getResultBuffer(rewriter, forOp->getOpResult(index));
         });
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
@@ -333,10 +335,10 @@ struct ForOpInterface
     yieldOp.getResultsMutable().assign(yieldValues);
 
     // Replace loop results.
-    state.replaceOp(op, newForOp->getResults());
+    state.replaceOp(rewriter, op, newForOp->getResults());
 
     // Bufferize loop body.
-    if (failed(comprehensive_bufferize::bufferize(loopBody, state)))
+    if (failed(comprehensive_bufferize::bufferize(rewriter, loopBody, state)))
       return failure();
 
     return success();

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index c837986cdeb29..0f91e52a5227e 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -65,7 +65,7 @@ struct CastOpInterface
                           BufferizationState &state) const {
     auto castOp = cast<tensor::CastOp>(op);
 
-    Value resultBuffer = state.getResultBuffer(castOp->getResult(0));
+    Value resultBuffer = state.getResultBuffer(rewriter, castOp->getResult(0));
     if (!resultBuffer)
       return failure();
     Type sourceType = resultBuffer.getType();
@@ -111,7 +111,7 @@ struct DimOpInterface
     auto dimOp = cast<tensor::DimOp>(op);
     if (!dimOp.source().getType().isa<RankedTensorType>())
       return dimOp.emitError("unranked tensor not supported");
-    Value v = state.lookupBuffer(dimOp.source());
+    Value v = state.lookupBuffer(rewriter, dimOp.source());
     state.replaceOpWithNewOp<memref::DimOp>(rewriter, op, v, dimOp.index());
     return success();
   }
@@ -147,7 +147,7 @@ struct ExtractSliceOpInterface
                           BufferizationState &state) const {
     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
     Location loc = extractSliceOp.getLoc();
-    Value srcMemref = state.lookupBuffer(extractSliceOp.source());
+    Value srcMemref = state.lookupBuffer(rewriter, extractSliceOp.source());
     auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
     auto dstTensorType =
         extractSliceOp.result().getType().cast<RankedTensorType>();
@@ -178,7 +178,7 @@ struct ExtractSliceOpInterface
       subView = alloc;
     }
 
-    state.replaceOp(op, subView);
+    state.replaceOp(rewriter, op, subView);
     return success();
   }
 };
@@ -204,7 +204,7 @@ struct ExtractOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto extractOp = cast<tensor::ExtractOp>(op);
-    Value srcMemref = state.lookupBuffer(extractOp.tensor());
+    Value srcMemref = state.lookupBuffer(rewriter, extractOp.tensor());
     state.replaceOpWithNewOp<memref::LoadOp>(rewriter, op, srcMemref,
                                              extractOp.indices());
     return success();
@@ -241,10 +241,11 @@ struct InsertOpInterface
                           BufferizationState &state) const {
     auto insertOp = cast<tensor::InsertOp>(op);
     Location loc = insertOp.getLoc();
-    Value destMemref = state.getResultBuffer(insertOp->getOpResult(0));
+    Value destMemref =
+        state.getResultBuffer(rewriter, insertOp->getOpResult(0));
     rewriter.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
                                      insertOp.indices());
-    state.replaceOp(op, destMemref);
+    state.replaceOp(rewriter, op, destMemref);
     return success();
   }
 
@@ -421,7 +422,8 @@ struct InsertSliceOpInterface
     TensorBufferizationState &tensorState = getTensorBufferizationState(state);
 
     // When bufferizing out-of-place, `getResultBuffer` allocates.
-    Value dstMemref = state.getResultBuffer(insertSliceOp->getResult(0));
+    Value dstMemref =
+        state.getResultBuffer(rewriter, insertSliceOp->getResult(0));
     if (!dstMemref)
       return failure();
 
@@ -440,11 +442,11 @@ struct InsertSliceOpInterface
           loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
           insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
       // Copy tensor.
-      Value srcMemref = state.lookupBuffer(insertSliceOp.source());
+      Value srcMemref = state.lookupBuffer(rewriter, insertSliceOp.source());
       state.createMemCpy(rewriter, insertSliceOp.getLoc(), srcMemref, subView);
     }
 
-    state.replaceOp(op, dstMemref);
+    state.replaceOp(rewriter, op, dstMemref);
     return success();
   }
 };

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index 73d89bc549fd8..c8d335e66bc8c 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -46,12 +46,12 @@ struct TransferReadOpInterface
            "only tensor types expected");
 
     // TransferReadOp always reads from the bufferized op.source().
-    Value buffer = state.lookupBuffer(readOp.source());
+    Value buffer = state.lookupBuffer(rewriter, readOp.source());
     Value read = rewriter.create<vector::TransferReadOp>(
         readOp.getLoc(), readOp.getVectorType(), buffer, readOp.indices(),
         readOp.permutation_map(), readOp.padding(), readOp.mask(),
         readOp.in_boundsAttr());
-    state.replaceOp(op, read);
+    state.replaceOp(rewriter, op, read);
     return success();
   }
 };
@@ -95,13 +95,13 @@ struct TransferWriteOpInterface
     // Create a new transfer_write on buffer that doesn't have a return value.
     // Leave the previous transfer_write to dead code as it still has uses at
     // this point.
-    Value resultBuffer = state.getResultBuffer(op->getResult(0));
+    Value resultBuffer = state.getResultBuffer(rewriter, op->getResult(0));
     if (!resultBuffer)
       return failure();
     rewriter.create<vector::TransferWriteOp>(
         writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
         writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
-    state.replaceOp(op, resultBuffer);
+    state.replaceOp(rewriter, op, resultBuffer);
 
     return success();
   }


        

