[Mlir-commits] [mlir] 6c6bba7 - [mlir][linalg][bufferize][NFC] Use RewriterBase instead of OpBuilder

Matthias Springer llvmlistbot at llvm.org
Wed Jan 5 04:05:57 PST 2022


Author: Matthias Springer
Date: 2022-01-05T21:05:42+09:00
New Revision: 6c6bba743674c4f72dfd1adb89d44475a9b3cb88

URL: https://github.com/llvm/llvm-project/commit/6c6bba743674c4f72dfd1adb89d44475a9b3cb88
DIFF: https://github.com/llvm/llvm-project/commit/6c6bba743674c4f72dfd1adb89d44475a9b3cb88.diff

LOG: [mlir][linalg][bufferize][NFC] Use RewriterBase instead of OpBuilder

This is in preparation of unifying core bufferization and Comprehensive Bufferize.

Differential Revision: https://reviews.llvm.org/D116102

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index bda6c25b2877..cfafc6b33bb7 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -14,6 +14,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/EquivalenceClasses.h"
 #include "llvm/ADT/SetVector.h"
@@ -296,7 +297,8 @@ struct DialectBufferizationState {
 /// * `replaceOp` replaces an op with new values.
 class BufferizationState {
 public:
-  BufferizationState(Operation *op, const BufferizationOptions &options);
+  BufferizationState(Operation *op, const BufferizationOptions &options,
+                     RewriterBase &rewriter);
 
   // BufferizationState should be passed as a reference.
   BufferizationState(const BufferizationState &) = delete;
@@ -387,9 +389,10 @@ class BufferizationState {
   /// Replace an op with a new op. Tensor OpResults must be replaced with memref
   /// values.
   template <typename OpTy, typename... Args>
-  OpTy replaceOpWithNewOp(OpBuilder &b, Operation *op, Args &&...args) {
+  OpTy replaceOpWithNewOp(RewriterBase &rewriter, Operation *op,
+                          Args &&...args) {
     Operation *newOp =
-        b.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
+        rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
     replaceOp(op, newOp->getResults());
     return cast<OpTy>(newOp);
   }
@@ -417,8 +420,8 @@ class BufferizationState {
   /// Return a reference to the BufferizationOptions.
   const BufferizationOptions &getOptions() const { return options; }
 
-  /// Return a reference to the OpBuilder.
-  OpBuilder &getBuilder() { return builder; }
+  /// Return a reference to the rewriter.
+  RewriterBase &getRewriter() { return rewriter; }
 
 private:
   friend LogicalResult
@@ -440,7 +443,7 @@ class BufferizationState {
   const BufferizationOptions &options;
 
   /// The OpBuilder used during bufferization.
-  OpBuilder builder;
+  RewriterBase &rewriter;
 };
 
 /// Bufferize all ops in the given region.
@@ -523,7 +526,7 @@ struct AllocationHoistingBarrierOnly
     return false;
   }
 
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
     if (any_of(op->getOperandTypes(), isaTensor) ||

diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
index df9090972bed..56c6b848c5f3 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
@@ -209,7 +209,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         }],
         /*retType=*/"LogicalResult",
         /*methodName=*/"bufferize",
-        /*args=*/(ins "OpBuilder &":$b,
+        /*args=*/(ins "RewriterBase &":$rewriter,
                       "BufferizationState &":$state),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
index e370d3f43042..e8d0fa984bb0 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
@@ -23,7 +23,7 @@ namespace arith_ext {
 struct ConstantOpInterface
     : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
                                                     arith::ConstantOp> {
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto constantOp = cast<arith::ConstantOp>(op);
     assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
@@ -35,8 +35,8 @@ struct ConstantOpInterface
 
     GlobalCreator globalCreator(moduleOp);
     auto globalMemref = globalCreator.getGlobalFor(constantOp);
-    state.replaceOpWithNewOp<memref::GetGlobalOp>(b, op, globalMemref.type(),
-                                                  globalMemref.getName());
+    state.replaceOpWithNewOp<memref::GetGlobalOp>(
+        rewriter, op, globalMemref.type(), globalMemref.getName());
     return success();
   }
 

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index f7d22251eadb..a639711196b4 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -333,8 +333,8 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
 }
 
 mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
-    Operation *op, const BufferizationOptions &options)
-    : aliasInfo(op), options(options), builder(op->getContext()) {
+    Operation *op, const BufferizationOptions &options, RewriterBase &rewriter)
+    : aliasInfo(op), options(options), rewriter(rewriter) {
   // Set up alias sets for OpResults that must bufferize in-place. This should
   // be done before making any other bufferization decisions.
   op->walk([&](BufferizableOpInterface bufferizableOp) {
@@ -361,7 +361,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
 /// bufferization is necessary.
 Value mlir::linalg::comprehensive_bufferize::BufferizationState::
     getResultBuffer(OpResult result) {
-  OpBuilder::InsertionGuard guard(builder);
+  OpBuilder::InsertionGuard guard(rewriter);
   Operation *op = result.getOwner();
   SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
   assert(!aliasingOperands.empty() && "could not get aliasing OpOperand");
@@ -391,9 +391,9 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
     Location loc = op->getLoc();
     // Move insertion point right after `operandBuffer`. That is where the
     // allocation should be inserted (in the absence of allocation hoisting).
-    setInsertionPointAfter(builder, operandBuffer);
+    setInsertionPointAfter(rewriter, operandBuffer);
     // Allocate the result buffer.
-    Value resultBuffer = createAllocDeallocPair(builder, loc, operandBuffer);
+    Value resultBuffer = createAllocDeallocPair(rewriter, loc, operandBuffer);
     bool skipCopy = false;
     // Do not copy if the last preceding write of `operand` is an op that does
     // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@@ -413,8 +413,8 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
       skipCopy = true;
     if (!skipCopy) {
       // The copy happens right before the op that is bufferized.
-      builder.setInsertionPoint(op);
-      createMemCpy(builder, loc, operandBuffer, resultBuffer);
+      rewriter.setInsertionPoint(op);
+      createMemCpy(rewriter, loc, operandBuffer, resultBuffer);
     }
     return resultBuffer;
   }
@@ -425,8 +425,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
 
 void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp(
     Operation *op, ValueRange values) {
-  OpBuilder &b = getBuilder();
-  OpBuilder::InsertionGuard g(b);
+  OpBuilder::InsertionGuard g(rewriter);
 
   // Replace all OpResults with the given values.
   for (OpResult opResult : op->getOpResults()) {
@@ -444,14 +443,14 @@ void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp(
       // The existing uses of the OpResult still expect a tensor. Insert a
       // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
       // loose all of its users and eventually DCE away.
-      setInsertionPointAfter(b, replacement);
-      replacement = b.create<bufferization::ToTensorOp>(replacement.getLoc(),
-                                                        replacement);
+      setInsertionPointAfter(rewriter, replacement);
+      replacement = rewriter.create<bufferization::ToTensorOp>(
+          replacement.getLoc(), replacement);
     }
     opResult.replaceAllUsesWith(replacement);
   }
 
-  op->erase();
+  rewriter.eraseOp(op);
 }
 
 LogicalResult
@@ -481,7 +480,7 @@ mlir::linalg::comprehensive_bufferize::bufferize(Block *block,
 LogicalResult
 mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
                                                  BufferizationState &state) {
-  OpBuilder &b = state.getBuilder();
+  RewriterBase &rewriter = state.getRewriter();
 
   // Check if op has tensor results or operands.
   auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
@@ -496,8 +495,8 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
   // Bufferize using `BufferizableOpInterface`. Interface implementations are
   // responsible for bufferizing nested ops.
   if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) {
-    b.setInsertionPoint(op);
-    return bufferizableOp.bufferize(b, state);
+    rewriter.setInsertionPoint(op);
+    return bufferizableOp.bufferize(rewriter, state);
   }
 
   // `op` is an unbufferizable tensor op.
@@ -679,10 +678,9 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
   }
 
   // Insert to_memref op.
-  OpBuilder &b = getBuilder();
-  OpBuilder::InsertionGuard g(b);
-  setInsertionPointAfter(b, tensor);
-  return b.create<bufferization::ToMemrefOp>(
+  OpBuilder::InsertionGuard g(rewriter);
+  setInsertionPointAfter(rewriter, tensor);
+  return rewriter.create<bufferization::ToMemrefOp>(
       tensor.getLoc(),
       getDynamicMemRefType(tensor.getType().cast<RankedTensorType>()), tensor);
 }

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
index 3419a6aa4492..eab925f02420 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
@@ -50,15 +50,14 @@ struct ToMemrefOpInterface
     return OpResult();
   }
 
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto toMemrefOp = cast<bufferization::ToMemrefOp>(op);
 
     // Fold to_memref(to_tensor(x)) to x.
     if (auto toTensorOp =
             toMemrefOp.tensor().getDefiningOp<bufferization::ToTensorOp>()) {
-      toMemrefOp.replaceAllUsesWith(toTensorOp.memref());
-      toMemrefOp.erase();
+      rewriter.replaceOp(toMemrefOp, toTensorOp.memref());
       return success();
     }
 
@@ -86,7 +85,7 @@ struct ToMemrefOpInterface
 struct ToTensorOpInterface
     : public BufferizableOpInterface::ExternalModel<ToTensorOpInterface,
                                                     bufferization::ToTensorOp> {
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     return success();
   }

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index c46746f6813f..66adbe7d1fc8 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -651,7 +651,8 @@ annotateOpsWithBufferizationMarkers(Operation *op,
 
 LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     Operation *op, std::unique_ptr<BufferizationOptions> options) {
-  BufferizationState state(op, *options);
+  IRRewriter rewriter(op->getContext());
+  BufferizationState state(op, *options, rewriter);
   return runComprehensiveBufferize(op, *options, state);
 }
 

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 190f0fea5108..9977e46b6878 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -23,11 +23,11 @@ namespace {
 // TODO: Ops in the linalg dialect can directly implement this interface.
 
 /// Generic conversion for any LinalgOp on tensors.
-static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
+static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
                                        BufferizationState &state) {
   // Take a guard before anything else.
-  OpBuilder::InsertionGuard g(b);
-  b.setInsertionPoint(op);
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
 
   // Nothing to do. This op is already bufferized.
   if (op.hasBufferSemantics())
@@ -63,9 +63,9 @@ static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
   newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
 
   // Set insertion point now that potential alloc/dealloc are introduced.
-  b.setInsertionPoint(op);
-  auto bufferizedOp = cast<LinalgOp>(
-      op.clone(b, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
+  rewriter.setInsertionPoint(op);
+  auto bufferizedOp = cast<LinalgOp>(op.clone(
+      rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
 
   // Replace the results of the old op with the new output buffers.
   state.replaceOp(op, newOutputBuffers);
@@ -177,9 +177,9 @@ struct LinalgOpInterface
     return BufferRelation::Equivalent;
   }
 
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
-    return bufferizeLinalgOp(b, cast<LinalgOp>(op), state);
+    return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), state);
   }
 };
 
@@ -192,7 +192,7 @@ struct InitTensorOpInterface
     return false;
   }
 
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto initTensorOp = cast<linalg::InitTensorOp>(op);
 
@@ -200,7 +200,7 @@ struct InitTensorOpInterface
     if (initTensorOp->getUses().empty())
       return success();
 
-    Value alloc = state.createAllocDeallocPair(b, initTensorOp->getLoc(),
+    Value alloc = state.createAllocDeallocPair(rewriter, initTensorOp->getLoc(),
                                                initTensorOp.result());
     state.replaceOp(op, alloc);
     return success();
@@ -251,15 +251,10 @@ struct TiledLoopOpInterface
 
   bool isAllocationHoistingBarrier(Operation *op) const { return true; }
 
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
 
-    // Use IRRewriter instead of OpBuilder because it has additional helper
-    // functions.
-    IRRewriter rewriter(op->getContext());
-    rewriter.setInsertionPoint(tiledLoopOp);
-
     // Compute new inputs, outputs and results.
     SmallVector<Value> newInputs, newOutputs, newResults;
     for (Value value : tiledLoopOp.inputs()) {
@@ -358,7 +353,7 @@ struct YieldOpInterface
     return OpResult();
   }
 
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto yieldOp = cast<linalg::YieldOp>(op);
 

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index e7a5330ef399..d622245718d6 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -725,7 +725,8 @@ static void annotateOpsWithBufferizationMarkers(FuncOp funcOp,
 
 LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     ModuleOp moduleOp, std::unique_ptr<BufferizationOptions> options) {
-  BufferizationState state(moduleOp, *options);
+  IRRewriter rewriter(moduleOp.getContext());
+  BufferizationState state(moduleOp, *options, rewriter);
   ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
   BufferizationAliasInfo &aliasInfo = state.aliasInfo;
 

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index 5db5deb6aee6..4b5eb1848ff7 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -60,7 +60,7 @@ struct ExecuteRegionOpInterface
     return true;
   }
 
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     // TODO: Add bufferization support when needed. scf.execute_region should be
     // bufferized similar to scf.if.
@@ -135,15 +135,10 @@ struct IfOpInterface
     return true;
   }
 
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto ifOp = cast<scf::IfOp>(op);
 
-    // Use IRRewriter instead of OpBuilder because it has additional helper
-    // functions.
-    IRRewriter rewriter(op->getContext());
-    rewriter.setInsertionPoint(ifOp);
-
     // Compute new types of the bufferized scf.if op.
     SmallVector<Type> newTypes;
     for (Type returnType : ifOp->getResultTypes()) {
@@ -276,16 +271,11 @@ struct ForOpInterface
     return true;
   }
 
-  LogicalResult bufferize(Operation *op, OpBuilder & /*b*/,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto forOp = cast<scf::ForOp>(op);
     Block *oldLoopBody = &forOp.getLoopBody().front();
 
-    // Use IRRewriter instead of OpBuilder because it has additional helper
-    // functions.
-    IRRewriter rewriter(op->getContext());
-    rewriter.setInsertionPoint(forOp);
-
     // Indices of all iter_args that have tensor type. These are the ones that
     // are bufferized.
     DenseSet<int64_t> indices;
@@ -438,7 +428,7 @@ struct YieldOpInterface
     return OpResult();
   }
 
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto yieldOp = cast<scf::YieldOp>(op);
     if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>(

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 30ca9ed0a78b..c837986cdeb2 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -61,7 +61,7 @@ struct CastOpInterface
     return BufferRelation::Equivalent;
   }
 
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto castOp = cast<tensor::CastOp>(op);
 
@@ -82,7 +82,8 @@ struct CastOpInterface
             : MemRefLayoutAttrInterface();
     Type memRefType = getContiguousOrUnrankedMemRefType(
         castOp.getResult().getType(), layout, memorySpace);
-    state.replaceOpWithNewOp<memref::CastOp>(b, op, memRefType, resultBuffer);
+    state.replaceOpWithNewOp<memref::CastOp>(rewriter, op, memRefType,
+                                             resultBuffer);
     return success();
   }
 };
@@ -105,13 +106,13 @@ struct DimOpInterface
     return OpResult();
   }
 
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto dimOp = cast<tensor::DimOp>(op);
     if (!dimOp.source().getType().isa<RankedTensorType>())
       return dimOp.emitError("unranked tensor not supported");
     Value v = state.lookupBuffer(dimOp.source());
-    state.replaceOpWithNewOp<memref::DimOp>(b, op, v, dimOp.index());
+    state.replaceOpWithNewOp<memref::DimOp>(rewriter, op, v, dimOp.index());
     return success();
   }
 };
@@ -142,7 +143,7 @@ struct ExtractSliceOpInterface
     return BufferRelation::None;
   }
 
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
     Location loc = extractSliceOp.getLoc();
@@ -155,7 +156,8 @@ struct ExtractSliceOpInterface
     bool inplace = state.isInPlace(extractSliceOp->getResult(0));
     Value alloc;
     if (!inplace)
-      alloc = state.createAllocDeallocPair(b, loc, extractSliceOp.result());
+      alloc =
+          state.createAllocDeallocPair(rewriter, loc, extractSliceOp.result());
 
     // Bufferize to subview.
     auto subviewMemRefType =
@@ -164,7 +166,7 @@ struct ExtractSliceOpInterface
             extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
             extractSliceOp.getMixedStrides())
             .cast<MemRefType>();
-    Value subView = b.create<memref::SubViewOp>(
+    Value subView = rewriter.create<memref::SubViewOp>(
         loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
         extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
 
@@ -172,7 +174,7 @@ struct ExtractSliceOpInterface
     if (!inplace) {
       // Do not copy if the copied data is never read.
       if (state.isValueRead(extractSliceOp.result()))
-        state.createMemCpy(b, extractSliceOp.getLoc(), subView, alloc);
+        state.createMemCpy(rewriter, extractSliceOp.getLoc(), subView, alloc);
       subView = alloc;
     }
 
@@ -199,11 +201,11 @@ struct ExtractOpInterface
     return OpResult();
   }
 
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto extractOp = cast<tensor::ExtractOp>(op);
     Value srcMemref = state.lookupBuffer(extractOp.tensor());
-    state.replaceOpWithNewOp<memref::LoadOp>(b, op, srcMemref,
+    state.replaceOpWithNewOp<memref::LoadOp>(rewriter, op, srcMemref,
                                              extractOp.indices());
     return success();
   }
@@ -235,13 +237,13 @@ struct InsertOpInterface
     return {&op->getOpOperand(1) /*dest*/};
   }
 
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto insertOp = cast<tensor::InsertOp>(op);
     Location loc = insertOp.getLoc();
     Value destMemref = state.getResultBuffer(insertOp->getOpResult(0));
-    b.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
-                              insertOp.indices());
+    rewriter.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
+                                     insertOp.indices());
     state.replaceOp(op, destMemref);
     return success();
   }
@@ -407,7 +409,7 @@ struct InsertSliceOpInterface
     return false;
   }
 
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     // insert_slice ops arise from tiling and bufferizing them out-of-place is
     // generally a deal breaker. When used with loops, this ends up cloning the
@@ -434,12 +436,12 @@ struct InsertSliceOpInterface
               insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
               insertSliceOp.getMixedStrides())
               .cast<MemRefType>();
-      Value subView = b.create<memref::SubViewOp>(
+      Value subView = rewriter.create<memref::SubViewOp>(
           loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
           insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
       // Copy tensor.
       Value srcMemref = state.lookupBuffer(insertSliceOp.source());
-      state.createMemCpy(b, insertSliceOp.getLoc(), srcMemref, subView);
+      state.createMemCpy(rewriter, insertSliceOp.getLoc(), srcMemref, subView);
     }
 
     state.replaceOp(op, dstMemref);

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index 50ceb5aa77c9..73d89bc549fd 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -39,7 +39,7 @@ struct TransferReadOpInterface
     return OpResult();
   }
 
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto readOp = cast<vector::TransferReadOp>(op);
     assert(readOp.getShapedType().isa<TensorType>() &&
@@ -47,7 +47,7 @@ struct TransferReadOpInterface
 
     // TransferReadOp always reads from the bufferized op.source().
     Value buffer = state.lookupBuffer(readOp.source());
-    Value read = b.create<vector::TransferReadOp>(
+    Value read = rewriter.create<vector::TransferReadOp>(
         readOp.getLoc(), readOp.getVectorType(), buffer, readOp.indices(),
         readOp.permutation_map(), readOp.padding(), readOp.mask(),
         readOp.in_boundsAttr());
@@ -86,7 +86,7 @@ struct TransferWriteOpInterface
     return BufferRelation::Equivalent;
   }
 
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto writeOp = cast<vector::TransferWriteOp>(op);
     assert(writeOp.getShapedType().isa<TensorType>() &&
@@ -98,7 +98,7 @@ struct TransferWriteOpInterface
     Value resultBuffer = state.getResultBuffer(op->getResult(0));
     if (!resultBuffer)
       return failure();
-    b.create<vector::TransferWriteOp>(
+    rewriter.create<vector::TransferWriteOp>(
         writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
         writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
     state.replaceOp(op, resultBuffer);


        


More information about the Mlir-commits mailing list