[Mlir-commits] [mlir] b55d55e - [mlir][bufferize][NFC] Remove BufferizationState

Matthias Springer llvmlistbot at llvm.org
Fri Jun 17 05:10:14 PDT 2022


Author: Matthias Springer
Date: 2022-06-17T14:04:11+02:00
New Revision: b55d55ecd9b2ce99b98bbb2595a1feb957d02a28

URL: https://github.com/llvm/llvm-project/commit/b55d55ecd9b2ce99b98bbb2595a1feb957d02a28
DIFF: https://github.com/llvm/llvm-project/commit/b55d55ecd9b2ce99b98bbb2595a1feb957d02a28.diff

LOG: [mlir][bufferize][NFC] Remove BufferizationState

With the recent refactorings, this class is no longer needed. We can use BufferizationOptions in all places where BufferizationState was used.

Differential Revision: https://reviews.llvm.org/D127653
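
The shape of the change can be sketched outside of MLIR. The snippet below is only an illustration under stated assumptions: `Options`, `Value`, `State`, `alignmentBefore` and `alignmentAfter` are hypothetical stand-ins invented for this sketch, not the real `BufferizationOptions`, `mlir::Value` or `BufferizationState`. It shows why a wrapper whose only member is a reference to the options can be replaced by free functions that take the options as a parameter.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-ins for illustration only; these are NOT the MLIR types.
struct Options {            // plays the role of BufferizationOptions
  unsigned bufferAlignment = 128;
};

struct Value {              // plays the role of an SSA value
  std::string name;
  bool isTensor;
};

// Before: a wrapper whose only state is a reference to the options.
struct State {
  explicit State(const Options &options) : options(options) {}

  // Member helper: callers wrote state.getBuffer(value).
  Value getBuffer(const Value &tensor) const {
    return Value{"to_memref(" + tensor.name + ")", /*isTensor=*/false};
  }

  const Options &getOptions() const { return options; }

private:
  const Options &options;
};

// After: a free function that receives the options directly, so callers
// write getBuffer(value, options) and no wrapper object is needed.
Value getBuffer(const Value &tensor, const Options & /*options*/) {
  return Value{"to_memref(" + tensor.name + ")", /*isTensor=*/false};
}

// Call site before the change: options are reached through the wrapper.
unsigned alignmentBefore(const State &state) {
  return state.getOptions().bufferAlignment;
}

// Call site after the change: options are used directly.
unsigned alignmentAfter(const Options &options) {
  return options.bufferAlignment;
}
```

In this sketch both call styles produce the same result, which is the sense in which the change is NFC: only the way the options reach the helper differs.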

Added: 
    

Modified: 
    mlir/docs/Bufferization.md
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
    mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
    mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
    mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
    mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff --git a/mlir/docs/Bufferization.md b/mlir/docs/Bufferization.md
index 8317f3e87fbff..7d13e9d22eab8 100644
--- a/mlir/docs/Bufferization.md
+++ b/mlir/docs/Bufferization.md
@@ -30,40 +30,38 @@ and with aggressive in-place bufferization.
 
 One-Shot Bufferize is:
 
-* **Monolithic**: A single MLIR pass does the entire
-work, whereas the previous bufferization in MLIR was split across multiple
-passes residing in different dialects. In One-Shot Bufferize,
-`BufferizableOpInterface` implementations are spread across different dialects.
-
-* A **whole-function at a time analysis**. In-place bufferization decisions are
-made by analyzing SSA use-def chains on tensors. Op interface implementations
-not only provide the rewrite logic from tensor ops to memref ops, but also
-helper methods for One-Shot Bufferize's analysis to query information about an
-op's bufferization/memory semantics.
-
-* **Extensible** via an op interface: All
-ops that implement `BufferizableOpInterface` can be bufferized.
-
-* **2-Pass**:
-Bufferization is internally broken down into 2 steps: First, analyze the entire
-IR and make bufferization decisions. Then, bufferize (rewrite) the IR. The
-analysis has access to exact SSA use-def information. It incrementally builds
-alias and equivalence sets and does not rely on a posteriori-alias analysis from
-preallocated memory.
-
-* **Greedy**: Operations are analyzed one-by-one and it is
-decided on the spot whether a tensor OpOperand must be copied or not. Heuristics
-determine the order of analysis.
-
-* **Modular**: The current One-Shot Analysis
-can be replaced with a different analysis. The result of the analysis are
-queried by the bufferization via `BufferizationState`, in particular
-`BufferizationState::isInPlace`. Any derived class of `BufferizationState` that
-implements a small number virtual functions can serve as a custom analysis. It
-is even possible to run One-Shot Bufferize without any analysis
-(`AlwaysCopyBufferizationState`), in which case One-Shot Bufferize behaves
-exactly like the old dialect conversion-based bufferization (i.e., copy every
-buffer before writing to it).
+*   **Monolithic**: A single MLIR pass does the entire work, whereas the
+    previous bufferization in MLIR was split across multiple passes residing in
+    different dialects. In One-Shot Bufferize, `BufferizableOpInterface`
+    implementations are spread across different dialects.
+
+*   A **whole-function at a time analysis**. In-place bufferization decisions
+    are made by analyzing SSA use-def chains on tensors. Op interface
+    implementations not only provide the rewrite logic from tensor ops to memref
+    ops, but also helper methods for One-Shot Bufferize's analysis to query
+    information about an op's bufferization/memory semantics.
+
+*   **Extensible** via an op interface: All ops that implement
+    `BufferizableOpInterface` can be bufferized.
+
+*   **2-Pass**: Bufferization is internally broken down into 2 steps: First,
+    analyze the entire IR and make bufferization decisions. Then, bufferize
+    (rewrite) the IR. The analysis has access to exact SSA use-def information.
+    It incrementally builds alias and equivalence sets and does not rely on a
+    posteriori-alias analysis from preallocated memory.
+
+*   **Greedy**: Operations are analyzed one-by-one and it is decided on the spot
+    whether a tensor OpOperand must be copied or not. Heuristics determine the
+    order of analysis.
+
+*   **Modular**: The current One-Shot Analysis can be replaced with a different
+    analysis. The result of the analysis are queried by the bufferization via
+    `AnalysisState`, in particular `AnalysisState::isInPlace`. Any derived class
+    of `AnalysisState` that implements a small number virtual functions can
+    serve as a custom analysis. It is even possible to run One-Shot Bufferize
+    without any analysis (`AlwaysCopyAnalysisState`), in which case One-Shot
+    Bufferize behaves exactly like the old dialect conversion-based
+    bufferization (i.e., copy every buffer before writing to it).
 
 To reduce complexity, One-Shot Bufferize should be
 [run after other transformations](https://llvm.discourse.group/t/rfc-linalg-on-tensors-update-and-comprehensive-bufferization-rfc/3373),

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 3cd9d70138d11..fa44fde98b6ed 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -236,7 +236,7 @@ struct BufferizationOptions {
   ///
   /// Note: Deactivating this flag can lead to incorrect bufferization results
   /// when used incorrectly. This flag is useful with
-  /// `AlwaysCopyBufferizationState` which bufferizes all writing tensor
+  /// `AlwaysCopyAnalysisState` which bufferizes all writing tensor
   /// OpOperands out-of-place.
   bool enforceAliasingInvariants = true;
 
@@ -464,33 +464,6 @@ class AnalysisState {
   const BufferizationOptions &options;
 };
 
-/// BufferizationState provides helper functions for performing bufferization
-/// rewrites and handling memref buffers.
-struct BufferizationState {
-  BufferizationState(const BufferizationOptions &options) : options(options) {}
-
-  /// Lookup the buffer for the given value. If the value was not bufferized
-  /// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp,
-  /// from which the memref operand is returned.
-  Value getBuffer(RewriterBase &rewriter, Value value);
-
-  /// Return the buffer type for a given Value (tensor) after bufferization.
-  ///
-  /// Note: Op implementations should preferrably call `getBuffer()->getType()`.
-  /// This function should only be used if `getBuffer` cannot be used.
-  BaseMemRefType getBufferType(Value value) const;
-
-  /// Return a reference to the BufferizationOptions.
-  const BufferizationOptions &getOptions() const { return options; }
-
-protected:
-  // BufferizationState should be passed as a reference.
-  BufferizationState(const BufferizationState &) = delete;
-
-private:
-  const BufferizationOptions &options;
-};
-
 /// Create an AllocTensorOp for the given shaped value (memref or tensor).
 /// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
 /// undefined contents is allocated.
@@ -498,6 +471,18 @@ Value allocateTensorForShapedValue(OpBuilder &b, Location loc,
                                    Value shapedValue, bool escape,
                                    bool copy = true);
 
+/// Lookup the buffer for the given value. If the value was not bufferized
+/// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp,
+/// from which the memref operand is returned.
+Value getBuffer(RewriterBase &rewriter, Value value,
+                const BufferizationOptions &options);
+
+/// Return the buffer type for a given Value (tensor) after bufferization.
+///
+/// Note: Op implementations should preferrably call `getBuffer()->getType()`.
+/// This function should only be used if `getBuffer` cannot be used.
+BaseMemRefType getBufferType(Value value, const BufferizationOptions &options);
+
 /// Replace an op with replacement values. The op is deleted. Tensor OpResults
 /// must be replaced with memref values.
 void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index e47e8478d4ad7..e550b900cb8ac 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -221,7 +221,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
       InterfaceMethod<
         /*desc=*/[{
           Bufferize this op, i.e., rewrite it into a memref-based equivalent.
-          Buffers of tensor SSA values can be retrieved via `state.getBuffer`.
+          Buffers of tensor SSA values can be retrieved via `getBuffer`.
           Uses of tensor results of the existing tensor op can be replaced with
           `replaceOpWithBufferizedValues` or `replaceOpWithNewBufferizedOp`.
           These two functions automatically handle the tensor-to-memref type
@@ -233,12 +233,6 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           a) A buffer that aliases one of buffers in getAliasingOpOperand(r).
           b) Or: A newly allocated buffer.
 
-          Regions of an op should be inlined into the new op instead of cloning
-          them. This is not only more efficient, but also necessary so that no
-          analysis results are lost. (Bufferization decisions are tracked via
-          OpOperand pointers and cloned ops have new OpOperands.) If regions are
-          cloned instead of inlined, additional buffer copies may be inserted.
-
           This method will never be called on ops that do not have at least one
           tensor operand/result.
 
@@ -252,7 +246,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*retType=*/"LogicalResult",
         /*methodName=*/"bufferize",
         /*args=*/(ins "RewriterBase &":$rewriter,
-                      "BufferizationState &":$state),
+                      "const BufferizationOptions &":$options),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           llvm_unreachable("bufferize not implemented");

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index a0509767cfed8..93154f48c32ec 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -71,7 +71,8 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
   let results = (outs AnyTensor:$result);
 
   let extraClassDeclaration = [{
-    LogicalResult bufferize(RewriterBase &rewriter, BufferizationState &state);
+    LogicalResult bufferize(RewriterBase &rewriter,
+                            const BufferizationOptions &options);
 
     bool isMemoryWrite(OpResult opResult, const AnalysisState &state);
 
@@ -242,7 +243,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     // results as not writable enforces a buffer copy and has the same effect.
 
     LogicalResult bufferize(RewriterBase &rewriter,
-                            BufferizationState &state) const {
+                            const BufferizationOptions &options) const {
       // to_tensor cannot be bufferized. However, other ops that are using
       // to_tensor's result will eventually be bufferized. At that point, they
       // will start using to_tensor's memref operand. Once all users of
@@ -334,7 +335,7 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
     }
 
     LogicalResult bufferize(RewriterBase &rewriter,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
   }];
 
   let assemblyFormat = "$tensor attr-dict `:` type($memref)";

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 92cb346b265f8..a2b7f7f5017d7 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -25,7 +25,6 @@ namespace mlir {
 namespace bufferization {
 
 class AnalysisState;
-struct BufferizationState;
 struct BufferizationOptions;
 class OpFilter;
 

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
index 63a1e07d16a22..782b2c4aeeda0 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
@@ -15,7 +15,6 @@ struct LogicalResult;
 class ModuleOp;
 
 namespace bufferization {
-struct BufferizationState;
 class OneShotAnalysisState;
 struct OneShotBufferizationOptions;
 

diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
index 7bd89762e8a41..b73c9039abcf6 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -23,7 +23,7 @@ struct ConstantOpInterface
     : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
                                                     arith::ConstantOp> {
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto constantOp = cast<arith::ConstantOp>(op);
 
     // Only ranked tensors are supported.
@@ -38,7 +38,7 @@ struct ConstantOpInterface
     // Create global memory segment and replace tensor with memref pointing to
     // that memory segment.
     FailureOr<memref::GlobalOp> globalOp =
-        getGlobalFor(constantOp, state.getOptions().bufferAlignment);
+        getGlobalFor(constantOp, options.bufferAlignment);
     if (failed(globalOp))
       return failure();
     memref::GlobalOp globalMemref = globalOp.getValue();
@@ -80,11 +80,11 @@ struct IndexCastOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto castOp = cast<arith::IndexCastOp>(op);
     auto resultTensorType = castOp.getType().cast<TensorType>();
 
-    Value source = state.getBuffer(rewriter, castOp.getIn());
+    Value source = getBuffer(rewriter, castOp.getIn(), options);
     auto sourceType = source.getType().cast<BaseMemRefType>();
 
     // Result type should have same layout and address space as the source type.
@@ -132,7 +132,7 @@ struct SelectOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto selectOp = cast<arith::SelectOp>(op);
     Location loc = selectOp.getLoc();
 
@@ -140,8 +140,8 @@ struct SelectOpInterface
     // instead of its OpOperands. In the worst case, 2 copies are inserted at
     // the moment (one for each tensor). When copying the op result, only one
     // copy would be needed.
-    Value trueBuffer = state.getBuffer(rewriter, selectOp.getTrueValue());
-    Value falseBuffer = state.getBuffer(rewriter, selectOp.getFalseValue());
+    Value trueBuffer = getBuffer(rewriter, selectOp.getTrueValue(), options);
+    Value falseBuffer = getBuffer(rewriter, selectOp.getFalseValue(), options);
 
     // The "true" and the "false" operands must have the same type. If the
     // buffers have different types, they differ only in their layout map. Cast

diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 8279d7efce65d..3e97ecdcbfff2 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -477,7 +477,8 @@ static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
 #endif
 }
 
-Value BufferizationState::getBuffer(RewriterBase &rewriter, Value value) {
+Value bufferization::getBuffer(RewriterBase &rewriter, Value value,
+                               const BufferizationOptions &options) {
   auto tensorType = value.getType().dyn_cast<TensorType>();
   assert(tensorType && "unexpected non-tensor type");
 
@@ -488,21 +489,22 @@ Value BufferizationState::getBuffer(RewriterBase &rewriter, Value value) {
   // Insert to_memref op.
   OpBuilder::InsertionGuard g(rewriter);
   setInsertionPointAfter(rewriter, value);
-  Type memrefType = getMemRefType(tensorType, getOptions());
+  Type memrefType = getMemRefType(tensorType, options);
   ensureToMemrefOpIsValid(value, memrefType);
   return rewriter.create<bufferization::ToMemrefOp>(value.getLoc(), memrefType,
                                                     value);
 }
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
-BaseMemRefType BufferizationState::getBufferType(Value value) const {
+BaseMemRefType
+bufferization::getBufferType(Value value, const BufferizationOptions &options) {
   auto tensorType = value.getType().dyn_cast<TensorType>();
   assert(tensorType && "unexpected non-tensor type");
 
   if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
     return toTensorOp.memref().getType().cast<BaseMemRefType>();
 
-  return getMemRefType(tensorType, getOptions());
+  return getMemRefType(tensorType, options);
 }
 
 void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,

diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index ad73f9da70fff..1b59a09280b20 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -150,7 +150,7 @@ void mlir::bufferization::populateDynamicDimSizes(
 //===----------------------------------------------------------------------===//
 
 LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
-                                       BufferizationState &state) {
+                                       const BufferizationOptions &options) {
   OpBuilder::InsertionGuard g(rewriter);
   Location loc = getLoc();
 
@@ -163,7 +163,7 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
   // Create buffer allocation.
   Value copyBuffer;
   if (copy())
-    copyBuffer = state.getBuffer(rewriter, copy());
+    copyBuffer = getBuffer(rewriter, copy(), options);
   auto allocType =
       MemRefType::get(getType().getShape(), getType().getElementType());
   SmallVector<Value> dynamicDims = dynamicSizes();
@@ -172,25 +172,24 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
     populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
   }
   FailureOr<Value> alloc =
-      state.getOptions().createAlloc(rewriter, loc, allocType, dynamicDims);
+      options.createAlloc(rewriter, loc, allocType, dynamicDims);
   if (failed(alloc))
     return failure();
 
   // Create memory copy (if any).
   if (copy()) {
-    if (failed(
-            state.getOptions().createMemCpy(rewriter, loc, copyBuffer, *alloc)))
+    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
       return failure();
   }
 
   // Should the buffer be deallocated?
-  AnalysisState analysisState(state.getOptions());
+  AnalysisState analysisState(options);
   bool dealloc;
   if (escape().hasValue()) {
     dealloc = !*escape();
   } else {
     // No "escape" annotation found.
-    if (state.getOptions().createDeallocs) {
+    if (options.createDeallocs) {
       // Perform an ad-hoc analysis.
       dealloc = !analysisState.isTensorYielded(getResult());
     } else {
@@ -206,7 +205,7 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
     return success();
 
   rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
-  if (failed(state.getOptions().createDealloc(rewriter, loc, *alloc)))
+  if (failed(options.createDealloc(rewriter, loc, *alloc)))
     return failure();
   return success();
 }
@@ -627,7 +626,7 @@ void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
-                                    BufferizationState &state) {
+                                    const BufferizationOptions &options) {
   // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
   (void)foldToMemrefToTensorPair(rewriter, *this);
   // Note: The return value of `bufferize` indicates whether there was an error

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 8f4d2066e092e..dd096d0f7f967 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -401,7 +401,6 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
   DenseSet<Operation *> erasedOps;
 
   // Bufferize all ops.
-  BufferizationState bufferizationState(options);
   BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
                                  worklist, options, opFilter);
   for (unsigned i = 0; i < worklist.size(); ++i) {
@@ -420,7 +419,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
       continue;
     // Bufferize the op.
     rewriter.setInsertionPoint(op);
-    if (failed(bufferizableOp.bufferize(rewriter, bufferizationState)))
+    if (failed(bufferizableOp.bufferize(rewriter, options)))
       return op->emitError("failed to bufferize op");
   }
 
@@ -433,7 +432,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
 
   /// Check the result of bufferization. Return an error if an op was not
   /// bufferized, unless partial bufferization is allowed.
-  if (bufferizationState.getOptions().allowUnknownOps)
+  if (options.allowUnknownOps)
     return success();
 
   for (Operation *op : worklist) {

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index e75338594d5bb..6805e76ca435c 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -258,7 +258,7 @@ struct CallOpInterface
   /// All function arguments are writable. It is the responsibility of the
   /// CallOp to insert buffer copies where necessary.
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     func::CallOp callOp = cast<func::CallOp>(op);
     unsigned numResults = callOp.getNumResults();
     unsigned numOperands = callOp->getNumOperands();
@@ -307,7 +307,7 @@ struct CallOpInterface
       // Retrieve buffers for tensor operands.
       Value buffer = newOperands[idx];
       if (!buffer)
-        buffer = state.getBuffer(rewriter, opOperand.get());
+        buffer = getBuffer(rewriter, opOperand.get(), options);
 
       // Caller / callee type mismatch is handled with a CastOp.
       auto memRefType = funcType.getInput(idx);
@@ -364,7 +364,7 @@ struct ReturnOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
 #ifndef NDEBUG
     auto returnOp = cast<func::ReturnOp>(op);
     assert(isa<FuncOp>(returnOp->getParentOp()) &&
@@ -386,11 +386,9 @@ struct FuncOpInterface
   /// All function bbArgs are writable unless they are explicitly marked as
   /// read-only. Callers must insert copies when needed.
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto funcOp = cast<FuncOp>(op);
     FunctionType funcType = funcOp.getFunctionType();
-    const OneShotBufferizationOptions &options =
-        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
 
     // Construct the bufferized function type.
     SmallVector<Type> argTypes;

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 773fa47692449..bb6f0532af705 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -429,7 +429,6 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   assert(options.bufferizeFunctionBoundaries &&
          "expected that function boundary bufferization is activated");
   IRRewriter rewriter(moduleOp.getContext());
-  BufferizationState bufferizationState(options);
 
   // A list of functions in the order in which they are analyzed + bufferized.
   SmallVector<func::FuncOp> orderedFuncOps;

diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 3ecab39cce61b..cc27b4403d898 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -20,11 +20,9 @@ using namespace mlir::bufferization;
 
 namespace {
 
-// TODO: Ops in the linalg dialect can directly implement this interface.
-
 /// Generic conversion for any LinalgOp on tensors.
 static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
-                                       BufferizationState &state) {
+                                       const BufferizationOptions &options) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(op);
@@ -46,14 +44,14 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
       newInputBuffers.push_back(opOperand->get());
       continue;
     }
-    newInputBuffers.push_back(state.getBuffer(rewriter, opOperand->get()));
+    newInputBuffers.push_back(getBuffer(rewriter, opOperand->get(), options));
   }
 
   // New output operands for the cloned op.
   SmallVector<Value> newOutputBuffers;
   for (OpResult opResult : op->getOpResults()) {
     OpOperand *opOperand = op.getOutputOperand(opResult.getResultNumber());
-    Value resultBuffer = state.getBuffer(rewriter, opOperand->get());
+    Value resultBuffer = getBuffer(rewriter, opOperand->get(), options);
     newOutputBuffers.push_back(resultBuffer);
   }
 
@@ -123,8 +121,8 @@ struct LinalgOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
-    return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), state);
+                          const BufferizationOptions &options) const {
+    return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), options);
   }
 };
 

diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 55b6518028847..4bff6b56e240c 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -73,7 +73,7 @@ struct ExecuteRegionOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
 
     // Compute new result types.
@@ -81,7 +81,7 @@ struct ExecuteRegionOpInterface
     for (Type type : executeRegionOp->getResultTypes()) {
       if (auto tensorType = type.dyn_cast<TensorType>()) {
         // TODO: Infer the result type instead of computing it.
-        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
+        newResultTypes.push_back(getMemRefType(tensorType, options));
       } else {
         newResultTypes.push_back(type);
       }
@@ -183,7 +183,7 @@ struct IfOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto ifOp = cast<scf::IfOp>(op);
 
     // Compute new types of the bufferized scf.if op.
@@ -191,7 +191,7 @@ struct IfOpInterface
     for (Type returnType : ifOp->getResultTypes()) {
       if (auto tensorType = returnType.dyn_cast<TensorType>()) {
         // TODO: Infer the result type instead of computing it.
-        newTypes.push_back(getMemRefType(tensorType, state.getOptions()));
+        newTypes.push_back(getMemRefType(tensorType, options));
       } else {
         newTypes.push_back(returnType);
       }
@@ -309,11 +309,11 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
 /// given OpOperands. If an operand is not a tensor, return the original value.
 static SmallVector<Value> getBuffers(RewriterBase &rewriter,
                                      MutableArrayRef<OpOperand> operands,
-                                     BufferizationState &state) {
+                                     const BufferizationOptions &options) {
   SmallVector<Value> result;
   for (OpOperand &opOperand : operands) {
     if (opOperand.get().getType().isa<TensorType>()) {
-      Value resultBuffer = state.getBuffer(rewriter, opOperand.get());
+      Value resultBuffer = getBuffer(rewriter, opOperand.get(), options);
       result.push_back(resultBuffer);
     } else {
       result.push_back(opOperand.get());
@@ -325,10 +325,11 @@ static SmallVector<Value> getBuffers(RewriterBase &rewriter,
 /// Helper function for loop bufferization. Compute the buffer that should be
 /// yielded from a loop block (loop body or loop condition).
 static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
-                              BaseMemRefType type, BufferizationState &state) {
+                              BaseMemRefType type,
+                              const BufferizationOptions &options) {
   assert(tensor.getType().isa<TensorType>() && "expected tensor");
   ensureToMemrefOpIsValid(tensor, type);
-  Value yieldedVal = state.getBuffer(rewriter, tensor);
+  Value yieldedVal = getBuffer(rewriter, tensor, options);
   return castBuffer(rewriter, yieldedVal, type);
 }
 
@@ -352,12 +353,12 @@ convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
 SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values,
                                     TypeRange bufferizedTypes,
                                     const DenseSet<int64_t> &tensorIndices,
-                                    BufferizationState &state) {
+                                    const BufferizationOptions &options) {
   return convertTensorValues(
       values, tensorIndices, [&](Value val, int64_t index) {
         return getYieldedBuffer(rewriter, val,
                                 bufferizedTypes[index].cast<BaseMemRefType>(),
-                                state);
+                                options);
       });
 }
 
@@ -472,7 +473,7 @@ struct ForOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto forOp = cast<scf::ForOp>(op);
     Block *oldLoopBody = &forOp.getLoopBody().front();
 
@@ -482,7 +483,7 @@ struct ForOpInterface
 
     // The new memref init_args of the loop.
     SmallVector<Value> initArgs =
-        getBuffers(rewriter, forOp.getIterOpOperands(), state);
+        getBuffers(rewriter, forOp.getIterOpOperands(), options);
 
     // Construct a new scf.for op with memref instead of tensor values.
     auto newForOp = rewriter.create<scf::ForOp>(
@@ -511,7 +512,7 @@ struct ForOpInterface
     auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
     rewriter.setInsertionPoint(yieldOp);
     SmallVector<Value> yieldValues = getYieldedValues(
-        rewriter, yieldOp.getResults(), initArgsTypes, indices, state);
+        rewriter, yieldOp.getResults(), initArgsTypes, indices, options);
     yieldOp.getResultsMutable().assign(yieldValues);
 
     // Replace loop results.
@@ -704,7 +705,7 @@ struct WhileOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto whileOp = cast<scf::WhileOp>(op);
 
     assert(whileOp.getBefore().getBlocks().size() == 1 &&
@@ -722,12 +723,12 @@ struct WhileOpInterface
 
     // The new memref init_args of the loop.
     SmallVector<Value> initArgs =
-        getBuffers(rewriter, whileOp->getOpOperands(), state);
+        getBuffers(rewriter, whileOp->getOpOperands(), options);
 
     // The result types of a WhileOp are the same as the "after" bbArg types.
     SmallVector<Type> argsTypesAfter = llvm::to_vector(
         llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
-          return state.getBufferType(bbArg).cast<Type>();
+          return getBufferType(bbArg, options).cast<Type>();
         }));
 
     // Construct a new scf.while op with memref instead of tensor values.
@@ -761,7 +762,7 @@ struct WhileOpInterface
     // TODO: This could be relaxed for better bufferization results.
     SmallVector<Value> newConditionArgs =
         getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
-                         indicesAfter, state);
+                         indicesAfter, options);
     newConditionOp.getArgsMutable().assign(newConditionArgs);
 
     // Set up new iter_args and move the loop body block to the new op.
@@ -780,7 +781,7 @@ struct WhileOpInterface
     // TODO: This could be relaxed for better bufferization results.
     SmallVector<Value> newYieldValues =
         getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
-                         indicesBefore, state);
+                         indicesBefore, options);
     newYieldOp.getResultsMutable().assign(newYieldValues);
 
     // Replace loop results.
@@ -866,7 +867,7 @@ struct YieldOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto yieldOp = cast<scf::YieldOp>(op);
     if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
             yieldOp->getParentOp()))
@@ -954,7 +955,7 @@ struct ForeachThreadOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &b,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     OpBuilder::InsertionGuard g(b);
     auto foreachThreadOp = cast<ForeachThreadOp>(op);
 
@@ -966,7 +967,7 @@ struct ForeachThreadOpInterface
       // Insert copies right before the PerformConcurrentlyOp terminator. They
       // should not be inside terminator (which would be the default insertion
       // point).
-      Value buffer = state.getBuffer(b, insertDest->get());
+      Value buffer = getBuffer(b, insertDest->get(), options);
       newResults.push_back(buffer);
     }
 
@@ -991,8 +992,7 @@ struct ForeachThreadOpInterface
         performConcurrentlyOp.walk([&](ParallelInsertSliceOp insertOp) {
           Location loc = insertOp.getLoc();
           Type srcType = getMemRefType(
-              insertOp.getSource().getType().cast<RankedTensorType>(),
-              state.getOptions());
+              insertOp.getSource().getType().cast<RankedTensorType>(), options);
           // ParallelInsertSliceOp bufferizes to a copy.
           auto srcMemref = b.create<bufferization::ToMemrefOp>(
               loc, srcType, insertOp.getSource());
@@ -1001,8 +1001,8 @@ struct ForeachThreadOpInterface
               loc, destMemref, insertOp.getMixedOffsets(),
               insertOp.getMixedSizes(), insertOp.getMixedStrides());
           // This memcpy will fold away if everything bufferizes in-place.
-          if (failed(state.getOptions().createMemCpy(b, insertOp.getLoc(),
-                                                     srcMemref, subview)))
+          if (failed(options.createMemCpy(b, insertOp.getLoc(), srcMemref,
+                                          subview)))
             return WalkResult::interrupt();
           b.eraseOp(insertOp);
           return WalkResult::advance();
@@ -1022,7 +1022,7 @@ struct PerformConcurrentlyOpInterface
     : public BufferizableOpInterface::ExternalModel<
           PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
   LogicalResult bufferize(Operation *op, RewriterBase &b,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     llvm_unreachable("op does not have any tensor OpOperands / OpResults");
     return failure();
   }
@@ -1110,7 +1110,7 @@ struct ParallelInsertSliceOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &b,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     // Will be bufferized as part of ForeachThreadOp.
     return failure();
   }

diff  --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index 1240b65d1a7e4..177d820ccefcb 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -59,7 +59,7 @@ struct AssumingOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto assumingOp = cast<shape::AssumingOp>(op);
 
     // Compute new result types.
@@ -67,7 +67,7 @@ struct AssumingOpInterface
     for (Type type : assumingOp->getResultTypes()) {
       if (auto tensorType = type.dyn_cast<TensorType>()) {
         // TODO: Infer the result type instead of computing it.
-        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
+        newResultTypes.push_back(getMemRefType(tensorType, options));
       } else {
         newResultTypes.push_back(type);
       }
@@ -152,7 +152,7 @@ struct AssumingYieldOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     // Op is bufferized as part of AssumingOp.
     return failure();
   }

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 430b2f6df8aa5..7695db6c59b4f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -48,11 +48,11 @@ struct CastOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto castOp = cast<tensor::CastOp>(op);
 
     // The result buffer still has the old (pre-cast) type.
-    Value resultBuffer = state.getBuffer(rewriter, castOp.source());
+    Value resultBuffer = getBuffer(rewriter, castOp.source(), options);
     auto sourceMemRefType = resultBuffer.getType().cast<BaseMemRefType>();
     Attribute memorySpace = sourceMemRefType.getMemorySpace();
     TensorType resultTensorType =
@@ -64,8 +64,8 @@ struct CastOpInterface
         layout = rankedMemRefType.getLayout();
 
     // Compute the new memref type.
-    Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
-                                          layout, memorySpace);
+    Type resultMemRefType =
+        getMemRefType(resultTensorType, options, layout, memorySpace);
 
     // Replace the op with a memref.cast.
     assert(memref::CastOp::areCastCompatible(resultBuffer.getType(),
@@ -105,10 +105,10 @@ struct CollapseShapeOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
     RankedTensorType tensorResultType = collapseShapeOp.getResultType();
-    Value buffer = state.getBuffer(rewriter, collapseShapeOp.src());
+    Value buffer = getBuffer(rewriter, collapseShapeOp.src(), options);
     auto bufferType = buffer.getType().cast<MemRefType>();
 
     if (tensorResultType.getRank() == 0) {
@@ -146,7 +146,7 @@ struct CollapseShapeOpInterface
         bufferType, collapseShapeOp.getReassociationIndices());
     if (!canBeCollapsed) {
       // TODO: Create alloc_tensor ops during TensorCopyInsertion.
-      AnalysisState analysisState(state.getOptions());
+      AnalysisState analysisState(options);
       Value tensorAlloc = allocateTensorForShapedValue(
           rewriter, op->getLoc(), collapseShapeOp.src(),
           analysisState.isTensorYielded(collapseShapeOp.result()));
@@ -185,9 +185,9 @@ struct DimOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto dimOp = cast<tensor::DimOp>(op);
-    auto v = state.getBuffer(rewriter, dimOp.source());
+    auto v = getBuffer(rewriter, dimOp.source(), options);
     replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
     return success();
   }
@@ -220,10 +220,10 @@ struct ExpandShapeOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
     auto tensorResultType = expandShapeOp.getResultType();
-    auto buffer = state.getBuffer(rewriter, expandShapeOp.src());
+    auto buffer = getBuffer(rewriter, expandShapeOp.src(), options);
 
     // Memref result type is inferred by the builder based on reassociation
     // indices and result shape.
@@ -261,13 +261,13 @@ struct ExtractSliceOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
     Location loc = extractSliceOp.getLoc();
 
     // Even if this op was decided to bufferize out-of-place, do not insert the
     // buffer copy yet. This is done later in this function.
-    auto srcMemref = state.getBuffer(rewriter, extractSliceOp.source());
+    auto srcMemref = getBuffer(rewriter, extractSliceOp.source(), options);
     auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
     auto dstTensorType =
         extractSliceOp.result().getType().cast<RankedTensorType>();
@@ -319,9 +319,9 @@ struct ExtractOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto extractOp = cast<tensor::ExtractOp>(op);
-    Value srcMemref = state.getBuffer(rewriter, extractOp.tensor());
+    Value srcMemref = getBuffer(rewriter, extractOp.tensor(), options);
     replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                  extractOp.indices());
     return success();
@@ -355,7 +355,7 @@ struct FromElementsOpInterface
     : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                     tensor::FromElementsOp> {
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
 
     // Allocate a buffer for the result.
@@ -363,7 +363,7 @@ struct FromElementsOpInterface
     auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
     auto shape = tensorType.getShape();
     // TODO: Create alloc_tensor ops during TensorCopyInsertion.
-    AnalysisState analysisState(state.getOptions());
+    AnalysisState analysisState(options);
     Value tensorAlloc = allocateTensorForShapedValue(
         rewriter, loc, fromElementsOp.result(),
         analysisState.isTensorYielded(fromElementsOp.result()),
@@ -410,13 +410,13 @@ struct GenerateOpInterface
     : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                     tensor::GenerateOp> {
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto generateOp = cast<tensor::GenerateOp>(op);
     auto tensorType = generateOp.getType().cast<RankedTensorType>();
     // Allocate memory.
     Location loc = op->getLoc();
     // TODO: Create alloc_tensor ops during TensorCopyInsertion.
-    AnalysisState analysisState(state.getOptions());
+    AnalysisState analysisState(options);
     Value tensorAlloc = allocateTensorForShapedValue(
         rewriter, loc, generateOp.result(),
         analysisState.isTensorYielded(generateOp.result()),
@@ -493,9 +493,9 @@ struct InsertOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto insertOp = cast<tensor::InsertOp>(op);
-    Value destMemref = state.getBuffer(rewriter, insertOp.dest());
+    Value destMemref = getBuffer(rewriter, insertOp.dest(), options);
     rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                      destMemref, insertOp.indices());
     replaceOpWithBufferizedValues(rewriter, op, destMemref);
@@ -645,7 +645,7 @@ struct InsertSliceOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     // insert_slice ops arise from tiling and bufferizing them out-of-place is
     // generally a deal breaker. When used with loops, this ends up cloning the
     // whole tensor on every single iteration and is a symptom of a
@@ -653,7 +653,7 @@ struct InsertSliceOpInterface
     // TODO: be very loud about it or even consider failing the pass.
     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
     Location loc = insertSliceOp.getLoc();
-    Value dstMemref = state.getBuffer(rewriter, insertSliceOp.dest());
+    Value dstMemref = getBuffer(rewriter, insertSliceOp.dest(), options);
 
     // Expand offsets, sizes and strides to the full rank to handle the
     // rank-reducing case.
@@ -681,9 +681,8 @@ struct InsertSliceOpInterface
 
     // Copy tensor. If this tensor.insert_slice has a matching
     // tensor.extract_slice, the copy operation will eventually fold away.
-    auto srcMemref = state.getBuffer(rewriter, insertSliceOp.source());
-    if (failed(
-            state.getOptions().createMemCpy(rewriter, loc, srcMemref, subView)))
+    auto srcMemref = getBuffer(rewriter, insertSliceOp.source(), options);
+    if (failed(options.createMemCpy(rewriter, loc, srcMemref, subView)))
       return failure();
 
     replaceOpWithBufferizedValues(rewriter, op, dstMemref);
@@ -711,9 +710,9 @@ struct RankOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto rankOp = cast<tensor::RankOp>(op);
-    auto v = state.getBuffer(rewriter, rankOp.tensor());
+    auto v = getBuffer(rewriter, rankOp.tensor(), options);
     replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                  v);
     return success();
@@ -747,12 +746,12 @@ struct ReshapeOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto reshapeOp = cast<tensor::ReshapeOp>(op);
-    Value srcBuffer = state.getBuffer(rewriter, reshapeOp.source());
-    Value shapeBuffer = state.getBuffer(rewriter, reshapeOp.shape());
+    Value srcBuffer = getBuffer(rewriter, reshapeOp.source(), options);
+    Value shapeBuffer = getBuffer(rewriter, reshapeOp.shape(), options);
     auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
-    auto resultMemRefType = getMemRefType(resultTensorType, state.getOptions());
+    auto resultMemRefType = getMemRefType(resultTensorType, options);
     replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
         rewriter, op, resultMemRefType, srcBuffer, shapeBuffer);
     return success();

diff  --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index b7344ee79481d..142e09bbb3da5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -46,11 +46,11 @@ struct TransferReadOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto readOp = cast<vector::TransferReadOp>(op);
     assert(readOp.getShapedType().isa<TensorType>() &&
            "only tensor types expected");
-    Value buffer = state.getBuffer(rewriter, readOp.getSource());
+    Value buffer = getBuffer(rewriter, readOp.getSource(), options);
     replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
         rewriter, readOp, readOp.getVectorType(), buffer, readOp.getIndices(),
         readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
@@ -91,13 +91,13 @@ struct TransferWriteOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto writeOp = cast<vector::TransferWriteOp>(op);
     assert(writeOp.getShapedType().isa<TensorType>() &&
            "only tensor types expected");
 
     // Create a new transfer_write on buffer that doesn't have a return value.
-    Value resultBuffer = state.getBuffer(rewriter, writeOp.getSource());
+    Value resultBuffer = getBuffer(rewriter, writeOp.getSource(), options);
     rewriter.create<vector::TransferWriteOp>(
         writeOp.getLoc(), writeOp.getVector(), resultBuffer,
         writeOp.getIndices(), writeOp.getPermutationMapAttr(),
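
The change repeated throughout this diff has one shape: a `bufferize` implementation no longer receives a `BufferizationState &state`, only a `const BufferizationOptions &options`, and calls such as `state.getBuffer(rewriter, v)` become free-function calls `getBuffer(rewriter, v, options)`. The following is a minimal standalone sketch of that refactoring pattern, not MLIR code. `Options`, `State`, `lookupBuffer`, `bufferizeOld`, and `bufferizeNew` are hypothetical stand-ins invented for illustration, and the string concatenation only mimics a lookup.

```cpp
#include <string>

// Hypothetical stand-in for BufferizationOptions: plain configuration data.
struct Options {
  std::string bufferPrefix = "memref_";
};

// "Before" shape: a wrapper class that only forwards to the options it holds.
class State {
public:
  explicit State(const Options &options) : options(options) {}
  const Options &getOptions() const { return options; }
  // Member helper, analogous in shape to state.getBuffer(...).
  std::string getBuffer(const std::string &tensor) const {
    return options.bufferPrefix + tensor;
  }

private:
  const Options &options;
};

// "After" shape: a free function that takes the options directly.
std::string lookupBuffer(const std::string &tensor, const Options &options) {
  return options.bufferPrefix + tensor;
}

// Old-style call site: depends on the wrapper.
std::string bufferizeOld(const std::string &tensor, State &state) {
  return state.getBuffer(tensor);
}

// New-style call site: depends only on the const options.
std::string bufferizeNew(const std::string &tensor, const Options &options) {
  return lookupBuffer(tensor, options);
}
```

Because the wrapper in this sketch carries no data beyond the options, both call sites compute the same result, which is the sense in which such a change can be marked NFC.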


        

