[Mlir-commits] [mlir] [MLIR] Add bufferization state to `getBufferType` and `resolveConflicts` interface methods (PR #141466)

Michele Scuttari llvmlistbot at llvm.org
Tue May 27 02:27:49 PDT 2025


https://github.com/mscuttari updated https://github.com/llvm/llvm-project/pull/141466

From f3a97b3dfe3fd7572b860a57949e31cd0172a762 Mon Sep 17 00:00:00 2001
From: Michele Scuttari <michele.scuttari at outlook.com>
Date: Mon, 26 May 2025 11:22:30 +0200
Subject: [PATCH 1/3] [MLIR] Add bufferization state to `getBufferType` and
 `resolveConflicts` interface methods

---
 .../IR/BufferizableOpInterface.h              |  10 +-
 .../IR/BufferizableOpInterface.td             |  11 +-
 .../Bufferization/IR/BufferizationOps.td      |   3 +-
 .../IR/UnstructuredControlFlow.h              |   5 +-
 .../Bufferization/Transforms/Bufferize.h      |   3 +-
 .../Bufferization/Transforms/Transforms.h     |   5 +-
 .../BufferizableOpInterfaceImpl.cpp           |  18 +--
 .../IR/BufferizableOpInterface.cpp            |  51 +++++----
 .../Bufferization/IR/BufferizationOps.cpp     |  17 +--
 .../Bufferization/Transforms/Bufferize.cpp    |  11 +-
 .../FuncBufferizableOpInterfaceImpl.cpp       |  13 ++-
 .../Transforms/OneShotAnalysis.cpp            |   2 +-
 .../Transforms/OneShotModuleBufferize.cpp     |   5 +-
 .../Transforms/TensorCopyInsertion.cpp        |  21 ++--
 .../BufferizableOpInterfaceImpl.cpp           |  26 ++---
 .../BufferizableOpInterfaceImpl.cpp           |   3 +-
 .../BufferizableOpInterfaceImpl.cpp           | 105 ++++++++++--------
 .../BufferizableOpInterfaceImpl.cpp           |   2 +-
 .../SparsificationAndBufferizationPass.cpp    |   6 +-
 .../BufferizableOpInterfaceImpl.cpp           |  83 ++++++++------
 .../BufferizableOpInterfaceImpl.cpp           |  17 ++-
 .../Bufferization/TestTensorCopyInsertion.cpp |   6 +-
 22 files changed, 248 insertions(+), 175 deletions(-)

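For downstream maintainers, the practical effect of this change is that a
BufferizationState must now be threaded through the getBufferType, getBuffer,
and resolveConflicts entry points. Below is a minimal, purely illustrative
sketch (not part of the patch) of how an external interface model might be
adapted; `MyOp` and `MyOpModel` are hypothetical placeholders for a downstream
op and its model.

  #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"

  using namespace mlir;
  using namespace mlir::bufferization;

  struct MyOpModel
      : public BufferizableOpInterface::ExternalModel<MyOpModel, MyOp> {
    FailureOr<BaseMemRefType>
    getBufferType(Operation *op, Value value,
                  const BufferizationOptions &options,
                  BufferizationState &state, // newly added parameter
                  SmallVector<Value> &invocationStack) const {
      // Forward the state to the free-standing helper, which now also
      // expects it (toy example: derive the type from operand 0).
      return bufferization::getBufferType(op->getOperand(0), options, state,
                                          invocationStack);
    }

    LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                            const BufferizationOptions &options,
                            BufferizationState &state) const {
      // getBuffer likewise takes the bufferization state now.
      FailureOr<Value> buffer =
          getBuffer(rewriter, op->getOperand(0), options, state);
      if (failed(buffer))
        return failure();
      replaceOpWithBufferizedValues(rewriter, op, *buffer);
      return success();
    }
  };

The same pattern applies to resolveConflicts overrides, which now receive the
analysis state and the bufferization state as separate arguments (see the
scf.for / scf.while changes below).
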
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 43c97d57e1834..328d928c9ebdb 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -598,13 +598,14 @@ class BufferizationState {
 FailureOr<Value>
 allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue,
                              const BufferizationOptions &options,
-                             bool copy = true);
+                             BufferizationState &state, bool copy = true);
 
 /// Lookup the buffer for the given value. If the value was not bufferized
 /// yet, wrap it in a ToBufferOp. Otherwise, it is the result of a ToTensorOp,
 /// from which the memref operand is returned.
 FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
-                           const BufferizationOptions &options);
+                           const BufferizationOptions &options,
+                           BufferizationState &state);
 
 /// Return the buffer type for a given Value (tensor) after bufferization
 /// without bufferizing any IR.
@@ -615,7 +616,8 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
 ///
 /// This function is a wrapper around BufferizableOpInterface::getBufferType.
 FailureOr<BaseMemRefType> getBufferType(Value value,
-                                        const BufferizationOptions &options);
+                                        const BufferizationOptions &options,
+                                        BufferizationState &state);
 
 /// Return the buffer type for a given Value (tensor) after bufferization
 /// without bufferizing any IR. This function (and not the other overload
@@ -629,6 +631,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
 /// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
 FailureOr<BaseMemRefType> getBufferType(Value value,
                                         const BufferizationOptions &options,
+                                        BufferizationState &state,
                                         SmallVector<Value> &invocationStack);
 
 /// Return "true" if the given op has tensor semantics and should be bufferized.
@@ -709,6 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
 /// places.
 FailureOr<BaseMemRefType>
 defaultGetBufferType(Value value, const BufferizationOptions &options,
+                     BufferizationState &state,
                      SmallVector<Value> &invocationStack);
 
 /// This is the default implementation of
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index b599a9f053215..80f9b72531660 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -381,13 +381,14 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*retType=*/"::llvm::LogicalResult",
         /*methodName=*/"resolveConflicts",
         /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
-                      "const ::mlir::bufferization::AnalysisState &":$state),
+                      "const ::mlir::bufferization::AnalysisState &":$analysisState,
+                      "::mlir::bufferization::BufferizationState &":$bufferizationState),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           auto bufferizableOp =
               ::llvm::cast<BufferizableOpInterface>($_op.getOperation());
           return bufferizableOp.resolveTensorOpOperandConflicts(
-              rewriter, state);
+              rewriter, analysisState, bufferizationState);
         }]
       >,
       InterfaceMethod<
@@ -523,6 +524,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*methodName=*/"getBufferType",
         /*args=*/(ins "::mlir::Value":$value,
                       "const ::mlir::bufferization::BufferizationOptions &":$options,
+                      "::mlir::bufferization::BufferizationState &":$state,
                       "::llvm::SmallVector<::mlir::Value> &":$invocationStack),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
@@ -531,7 +533,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           assert(invocationStack.back() == value &&
                  "inconsistant invocation stack");
           return ::mlir::bufferization::detail::defaultGetBufferType(
-              value, options, invocationStack);
+              value, options, state, invocationStack);
         }]
       >,
       InterfaceMethod<
@@ -616,7 +618,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
     /// form of `bufferization.alloc_tensor` ops.
     ::llvm::LogicalResult resolveTensorOpOperandConflicts(
         ::mlir::RewriterBase &rewriter,
-        const ::mlir::bufferization::AnalysisState &state);
+        const ::mlir::bufferization::AnalysisState &analysisState,
+        ::mlir::bufferization::BufferizationState &bufferizationState);
 
     /// Return `true` if the given OpOperand creates an alias but does neither
     /// read nor write. This implies that `bufferizesToMemoryRead` and
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index dafa4b9b183f2..0ee4f79144158 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -112,6 +112,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
 
     FailureOr<BaseMemRefType> getBufferType(
         Value value, const BufferizationOptions &options,
+        BufferizationState &state,
         SmallVector<Value> &invocationStack);
 
     RankedTensorType getType() {
@@ -471,7 +472,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
     FailureOr<BaseMemRefType> getBufferType(
         Value value, const BufferizationOptions &options,
-        SmallVector<Value> &invocationStack) {
+        BufferizationState &state, SmallVector<Value> &invocationStack) {
       return ::llvm::cast<BaseMemRefType>(getMemref().getType());
     }
   }];
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index cf86b9a23f59e..7c07f705c8435 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -34,12 +34,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+    BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     // Note: The user may want to override this function for OpResults in
     // case the bufferized result type is different from the bufferized type of
     // the aliasing OpOperand (if any).
     if (isa<OpResult>(value))
-      return bufferization::detail::defaultGetBufferType(value, options,
+      return bufferization::detail::defaultGetBufferType(value, options, state,
                                                          invocationStack);
 
     // Compute the buffer type of the block argument by computing the bufferized
@@ -65,7 +66,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
         callerType = memrefType;
       } else {
         FailureOr<BaseMemRefType> maybeCallerType =
-            bufferization::getBufferType(opOperand->get(), options,
+            bufferization::getBufferType(opOperand->get(), options, state,
                                          invocationStack);
         if (failed(maybeCallerType))
           return failure();
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 70e3defee0867..c1f5654abbf9b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -62,7 +62,8 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
 /// `BufferizableOpInterface`. The buffer types of tensor block arguments are
 /// computed with `BufferizableOpIntercace::getBufferType`.
 LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
-                                      const BufferizationOptions &options);
+                                      const BufferizationOptions &options,
+                                      BufferizationState &state);
 
 } // namespace bufferization
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
index a4ee893ca5341..e587753ddebee 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -75,12 +75,15 @@ void hoistBuffersFromLoops(Operation *op);
 /// additional buffer allocations.
 LogicalResult insertTensorCopies(Operation *op,
                                  const OneShotBufferizationOptions &options,
+                                 BufferizationState &bufferizationState,
                                  BufferizationStatistics *statistics = nullptr);
 
 /// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
 /// After applying this transform, the IR can be bufferized without inserting
 /// additional buffer allocations.
-LogicalResult insertTensorCopies(Operation *op, const AnalysisState &state);
+LogicalResult insertTensorCopies(Operation *op,
+                                 const AnalysisState &analysisState,
+                                 BufferizationState &bufferizationState);
 
 /// Populate patterns to lower tensor.empty ops to bufferization.alloc_tensor
 /// ops.
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index f646326ffc58f..0389a984e169c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -90,7 +90,8 @@ struct IndexCastOpInterface
     auto castOp = cast<arith::IndexCastOp>(op);
     auto resultTensorType = cast<TensorType>(castOp.getType());
 
-    FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options);
+    FailureOr<Value> source =
+        getBuffer(rewriter, castOp.getIn(), options, state);
     if (failed(source))
       return failure();
     auto sourceType = cast<BaseMemRefType>(source->getType());
@@ -151,9 +152,9 @@ struct SelectOpInterface
     // the moment (one for each tensor). When copying the op result, only one
     // copy would be needed.
     FailureOr<Value> maybeTrueBuffer =
-        getBuffer(rewriter, selectOp.getTrueValue(), options);
+        getBuffer(rewriter, selectOp.getTrueValue(), options, state);
     FailureOr<Value> maybeFalseBuffer =
-        getBuffer(rewriter, selectOp.getFalseValue(), options);
+        getBuffer(rewriter, selectOp.getFalseValue(), options, state);
     if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer))
       return failure();
     Value trueBuffer = *maybeTrueBuffer;
@@ -164,7 +165,7 @@ struct SelectOpInterface
     // both of them to the most dynamic MemRef type.
     if (trueBuffer.getType() != falseBuffer.getType()) {
       auto targetType =
-          bufferization::getBufferType(selectOp.getResult(), options);
+          bufferization::getBufferType(selectOp.getResult(), options, state);
       if (failed(targetType))
         return failure();
       if (trueBuffer.getType() != *targetType)
@@ -182,13 +183,14 @@ struct SelectOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto selectOp = cast<arith::SelectOp>(op);
     assert(value == selectOp.getResult() && "invalid value");
-    auto trueType = bufferization::getBufferType(selectOp.getTrueValue(),
-                                                 options, invocationStack);
-    auto falseType = bufferization::getBufferType(selectOp.getFalseValue(),
-                                                  options, invocationStack);
+    auto trueType = bufferization::getBufferType(
+        selectOp.getTrueValue(), options, state, invocationStack);
+    auto falseType = bufferization::getBufferType(
+        selectOp.getFalseValue(), options, state, invocationStack);
     if (failed(trueType) || failed(falseType))
       return failure();
     if (*trueType == *falseType)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 14fa4c1ed8159..7d67d4a33ac32 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -165,7 +165,7 @@ Operation *bufferization::getOwnerOfValue(Value value) {
 /// allocated.
 FailureOr<Value> bufferization::allocateTensorForShapedValue(
     OpBuilder &b, Location loc, Value shapedValue,
-    const BufferizationOptions &options, bool copy) {
+    const BufferizationOptions &options, BufferizationState &state, bool copy) {
   Value tensor;
   if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
     tensor = shapedValue;
@@ -210,7 +210,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
   // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
   if (copy)
     return allocTensorOp.getResult();
-  FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
+  FailureOr<BaseMemRefType> copyBufferType =
+      getBufferType(tensor, options, state);
   if (failed(copyBufferType))
     return failure();
   std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
@@ -222,7 +223,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
 }
 
 LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
-    RewriterBase &rewriter, const AnalysisState &state) {
+    RewriterBase &rewriter, const AnalysisState &analysisState,
+    BufferizationState &bufferizationState) {
   OpBuilder::InsertionGuard g(rewriter);
   Operation *op = getOperation();
   SmallVector<OpOperand *> outOfPlaceOpOperands;
@@ -235,16 +237,18 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
     Type operandType = opOperand.get().getType();
     if (!llvm::isa<TensorType>(operandType))
       continue;
-    if (state.isInPlace(opOperand))
+    if (analysisState.isInPlace(opOperand))
       continue;
     if (llvm::isa<UnrankedTensorType>(operandType))
       return op->emitError("copying of unranked tensors is not implemented");
 
-    AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
+    AliasingValueList aliasingValues =
+        analysisState.getAliasingValues(opOperand);
     if (aliasingValues.getNumAliases() == 1 &&
         isa<OpResult>(aliasingValues.getAliases()[0].value) &&
-        !state.bufferizesToMemoryWrite(opOperand) &&
-        state.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
+        !analysisState.bufferizesToMemoryWrite(opOperand) &&
+        analysisState
+                .getAliasingOpOperands(aliasingValues.getAliases()[0].value)
                 .getNumAliases() == 1 &&
         !isa<UnrankedTensorType>(
             aliasingValues.getAliases()[0].value.getType())) {
@@ -256,12 +260,12 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
       // cannot be copied at the moment).
       Value value = aliasingValues.getAliases()[0].value;
       outOfPlaceValues.push_back(value);
-      if (!state.canOmitTensorCopy(opOperand))
+      if (!analysisState.canOmitTensorCopy(opOperand))
         copiedOpValues.insert(value);
     } else {
       // In all other cases, make a copy of the OpOperand.
       outOfPlaceOpOperands.push_back(&opOperand);
-      if (!state.canOmitTensorCopy(opOperand))
+      if (!analysisState.canOmitTensorCopy(opOperand))
         copiedOpOperands.insert(&opOperand);
     }
   }
@@ -270,8 +274,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
   rewriter.setInsertionPoint(op);
   for (OpOperand *opOperand : outOfPlaceOpOperands) {
     FailureOr<Value> copy = allocateTensorForShapedValue(
-        rewriter, op->getLoc(), opOperand->get(), state.getOptions(),
-        copiedOpOperands.contains(opOperand));
+        rewriter, op->getLoc(), opOperand->get(), analysisState.getOptions(),
+        bufferizationState, copiedOpOperands.contains(opOperand));
     if (failed(copy))
       return failure();
     rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); });
@@ -281,8 +285,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
   rewriter.setInsertionPointAfter(op);
   for (Value value : outOfPlaceValues) {
     FailureOr<Value> copy = allocateTensorForShapedValue(
-        rewriter, op->getLoc(), value, state.getOptions(),
-        copiedOpValues.count(value));
+        rewriter, op->getLoc(), value, analysisState.getOptions(),
+        bufferizationState, copiedOpValues.count(value));
     if (failed(copy))
       return failure();
     SmallVector<OpOperand *> uses = llvm::to_vector(
@@ -665,7 +669,8 @@ static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
 }
 
 FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
-                                          const BufferizationOptions &options) {
+                                          const BufferizationOptions &options,
+                                          BufferizationState &state) {
 #ifndef NDEBUG
   auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
   assert(tensorType && "unexpected non-tensor type");
@@ -678,7 +683,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
   // Insert to_buffer op.
   OpBuilder::InsertionGuard g(rewriter);
   setInsertionPointAfter(rewriter, value);
-  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
+  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options, state);
   if (failed(memrefType))
     return failure();
   ensureToBufferOpIsValid(value, *memrefType);
@@ -689,14 +694,16 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
 FailureOr<BaseMemRefType>
-bufferization::getBufferType(Value value, const BufferizationOptions &options) {
+bufferization::getBufferType(Value value, const BufferizationOptions &options,
+                             BufferizationState &state) {
   SmallVector<Value> invocationStack;
-  return getBufferType(value, options, invocationStack);
+  return getBufferType(value, options, state, invocationStack);
 }
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
 FailureOr<BaseMemRefType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options,
+                             BufferizationState &state,
                              SmallVector<Value> &invocationStack) {
   assert(llvm::isa<TensorType>(value.getType()) &&
          "unexpected non-tensor type");
@@ -708,7 +715,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
   Operation *op = getOwnerOfValue(value);
   auto bufferizableOp = options.dynCastBufferizableOp(op);
   if (bufferizableOp)
-    return bufferizableOp.getBufferType(value, options, invocationStack);
+    return bufferizableOp.getBufferType(value, options, state, invocationStack);
 
   // Op is not bufferizable.
   auto memSpace =
@@ -944,6 +951,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
 
 FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
     Value value, const BufferizationOptions &options,
+    BufferizationState &bufferizationState,
     SmallVector<Value> &invocationStack) {
   assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
 
@@ -954,14 +962,15 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   // Value is an OpResult.
   Operation *op = getOwnerOfValue(value);
   auto opResult = llvm::cast<OpResult>(value);
-  AnalysisState state(options);
-  AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
+  AnalysisState analysisState(options);
+  AliasingOpOperandList aliases = analysisState.getAliasingOpOperands(opResult);
   if (aliases.getNumAliases() > 0 &&
       aliases.getAliases()[0].relation == BufferRelation::Equivalent) {
     // If the OpResult has an equivalent OpOperand, both OpResult and
     // OpOperand bufferize to the exact same buffer type.
     Value equivalentOperand = aliases.getAliases().front().opOperand->get();
-    return getBufferType(equivalentOperand, options, invocationStack);
+    return getBufferType(equivalentOperand, options, bufferizationState,
+                         invocationStack);
   }
 
   // If we do not know the memory space and there is no default memory space,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 91eccb0ab7430..41b86437e11cf 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -163,14 +163,15 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
   // Get "copy" buffer.
   Value copyBuffer;
   if (getCopy()) {
-    FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
+    FailureOr<Value> maybeCopyBuffer =
+        getBuffer(rewriter, getCopy(), options, state);
     if (failed(maybeCopyBuffer))
       return failure();
     copyBuffer = *maybeCopyBuffer;
   }
 
   // Create memory allocation.
-  auto allocType = bufferization::getBufferType(getResult(), options);
+  auto allocType = bufferization::getBufferType(getResult(), options, state);
   if (failed(allocType))
     return failure();
   SmallVector<Value> dynamicDims = getDynamicSizes();
@@ -223,6 +224,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
 
 FailureOr<BaseMemRefType>
 AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
+                             BufferizationState &state,
                              SmallVector<Value> &invocationStack) {
   assert(value == getResult() && "invalid value");
 
@@ -231,8 +233,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
   if (getMemorySpace().has_value()) {
     memorySpace = *getMemorySpace();
   } else if (getCopy()) {
-    auto copyBufferType =
-        bufferization::getBufferType(getCopy(), options, invocationStack);
+    auto copyBufferType = bufferization::getBufferType(getCopy(), options,
+                                                       state, invocationStack);
     if (failed(copyBufferType))
       return failure();
     memorySpace = copyBufferType->getMemorySpace();
@@ -532,7 +534,7 @@ void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
 LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
                                          const BufferizationOptions &options,
                                          BufferizationState &state) {
-  FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
+  FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options, state);
   if (failed(buffer))
     return failure();
   rewriter.create<memref::DeallocOp>(getLoc(), *buffer);
@@ -583,7 +585,8 @@ MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
   bool tensorDest = isa<TensorType>(getDest().getType());
   Value buffer;
   if (tensorDest) {
-    FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options);
+    FailureOr<Value> maybeBuffer =
+        getBuffer(rewriter, getDest(), options, state);
     if (failed(maybeBuffer))
       return failure();
     buffer = *maybeBuffer;
@@ -591,7 +594,7 @@ MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
     assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
     buffer = getDest();
   }
-  auto srcBuffer = getBuffer(rewriter, getSource(), options);
+  auto srcBuffer = getBuffer(rewriter, getSource(), options, state);
   if (failed(srcBuffer))
     return failure();
   if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 67f373d912dd4..c7681d309a4af 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -280,8 +280,8 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
                                          BufferizationState &bufferizationState,
                                          BufferizationStatistics *statistics) {
   if (options.copyBeforeWrite) {
-    AnalysisState state(options);
-    if (failed(insertTensorCopies(op, state)))
+    AnalysisState analysisState(options);
+    if (failed(insertTensorCopies(op, analysisState, bufferizationState)))
       return failure();
   }
 
@@ -396,7 +396,8 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
 
 LogicalResult
 bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
-                                       const BufferizationOptions &options) {
+                                       const BufferizationOptions &options,
+                                       BufferizationState &state) {
   OpBuilder::InsertionGuard g(rewriter);
   auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp());
   if (!bufferizableOp)
@@ -412,7 +413,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
     }
 
     FailureOr<BaseMemRefType> memrefType =
-        bufferization::getBufferType(bbArg, options);
+        bufferization::getBufferType(bbArg, options, state);
     if (failed(memrefType))
       return failure();
     newTypes.push_back(*memrefType);
@@ -463,7 +464,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
         continue;
       }
       FailureOr<BaseMemRefType> operandBufferType =
-          bufferization::getBufferType(operand, options);
+          bufferization::getBufferType(operand, options, state);
       if (failed(operandBufferType))
         return failure();
       rewriter.setInsertionPointAfterValue(operand);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 6210f1d787bf4..9a0a85a71debd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -213,6 +213,7 @@ struct CallOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto callOp = cast<func::CallOp>(op);
 
@@ -255,7 +256,7 @@ struct CallOpInterface
 
       // Returning a memref.
       FailureOr<BaseMemRefType> resultType =
-          bufferization::getBufferType(result, options);
+          bufferization::getBufferType(result, options, state);
       if (failed(resultType))
         return failure();
       resultTypes.push_back(*resultType);
@@ -278,7 +279,7 @@ struct CallOpInterface
 
       // Retrieve buffers for tensor operands.
       FailureOr<Value> maybeBuffer =
-          getBuffer(rewriter, opOperand.get(), options);
+          getBuffer(rewriter, opOperand.get(), options, state);
       if (failed(maybeBuffer))
         return failure();
       Value buffer = *maybeBuffer;
@@ -291,7 +292,8 @@ struct CallOpInterface
         // result type.
         FailureOr<BaseMemRefType> maybeMemRefType =
             bufferization::getBufferType(
-                funcOp.getArgument(opOperand.getOperandNumber()), options);
+                funcOp.getArgument(opOperand.getOperandNumber()), options,
+                state);
         if (failed(maybeMemRefType))
           return failure();
         memRefType = *maybeMemRefType;
@@ -396,6 +398,7 @@ struct FuncOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto funcOp = cast<FuncOp>(op);
     auto bbArg = cast<BlockArgument>(value);
@@ -406,7 +409,7 @@ struct FuncOpInterface
                                           options);
 
     return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
-        getBufferType(op, value, options, invocationStack);
+        getBufferType(op, value, options, state, invocationStack);
   }
 
   /// Rewrite function bbArgs and return values into buffer form. This function
@@ -459,7 +462,7 @@ struct FuncOpInterface
     // 1. Bufferize every block.
     for (Block &block : funcOp.getBody())
       if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
-                                                        options)))
+                                                        options, state)))
         return failure();
 
     // 2. Bufferize the operands of the all return op.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index de820e9c8f8af..33a922d59224b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -1379,7 +1379,7 @@ LogicalResult bufferization::runOneShotBufferize(
     // Run One-Shot Analysis and insert buffer copies (on the tensor level)
     // only where needed. This is the default and much more efficient than
     // copy-before-write.
-    if (failed(insertTensorCopies(op, options, statistics)))
+    if (failed(insertTensorCopies(op, options, state, statistics)))
       return failure();
 
     // If test-analysis-only is set, the IR was annotated with RaW conflict
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 90ceea4d69680..dee2af8271ce8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -584,7 +584,7 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
          "invalid combination of bufferization flags");
   if (!options.copyBeforeWrite) {
     if (options.noAnalysisFuncFilter.empty()) {
-      if (failed(insertTensorCopies(moduleOp, options, statistics)))
+      if (failed(insertTensorCopies(moduleOp, options, state, statistics)))
         return failure();
     } else {
       // FuncOps whose names are specified in options.noAnalysisFuncFilter will
@@ -600,7 +600,8 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
       };
       OneShotBufferizationOptions updatedOptions(options);
       updatedOptions.opFilter.denyOperation(analysisFilterFn);
-      if (failed(insertTensorCopies(moduleOp, updatedOptions, statistics)))
+      if (failed(
+              insertTensorCopies(moduleOp, updatedOptions, state, statistics)))
         return failure();
     }
   }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index 4326b19f3104d..d971ed5b0f71c 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -28,28 +28,29 @@ using namespace mlir::bufferization;
 
 LogicalResult mlir::bufferization::insertTensorCopies(
     Operation *op, const OneShotBufferizationOptions &options,
+    BufferizationState &bufferizationState,
     BufferizationStatistics *statistics) {
-  OneShotAnalysisState state(op, options);
+  OneShotAnalysisState analysisState(op, options);
   // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
   // analysis depending on whether function boundary bufferization is enabled or
   // not.
   if (options.bufferizeFunctionBoundaries) {
-    if (failed(analyzeModuleOp(cast<ModuleOp>(op), state, statistics)))
+    if (failed(analyzeModuleOp(cast<ModuleOp>(op), analysisState, statistics)))
       return failure();
   } else {
-    if (failed(analyzeOp(op, state, statistics)))
+    if (failed(analyzeOp(op, analysisState, statistics)))
       return failure();
   }
 
   if (options.testAnalysisOnly)
     return success();
 
-  return insertTensorCopies(op, state);
+  return insertTensorCopies(op, analysisState, bufferizationState);
 }
 
-LogicalResult
-mlir::bufferization::insertTensorCopies(Operation *op,
-                                        const AnalysisState &state) {
+LogicalResult mlir::bufferization::insertTensorCopies(
+    Operation *op, const AnalysisState &analysisState,
+    BufferizationState &bufferizationState) {
   IRRewriter rewriter(op->getContext());
 
   // It may be more efficient to walk in pre-order here, but the current
@@ -62,14 +63,16 @@ mlir::bufferization::insertTensorCopies(Operation *op,
         nestedOp->getParentWithTrait<OpTrait::SymbolTable>() != op)
       return WalkResult::skip();
 
-    auto bufferizableOp = state.getOptions().dynCastBufferizableOp(nestedOp);
+    auto bufferizableOp =
+        analysisState.getOptions().dynCastBufferizableOp(nestedOp);
     if (!bufferizableOp)
       return WalkResult::skip();
 
     // Find inplacability conflicts and resolve them. (Typically with explicit
     // tensor copies in the form of AllocTensorOps.)
     rewriter.setInsertionPoint(nestedOp);
-    if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
+    if (failed(bufferizableOp.resolveConflicts(rewriter, analysisState,
+                                               bufferizationState)))
       return WalkResult::interrupt();
 
     return WalkResult::advance();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index b6a498a57c036..ce355e96ee694 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -24,10 +24,9 @@ using namespace mlir::bufferization;
 namespace {
 
 /// Generic conversion for any DestinationStyleOpInterface on tensors.
-static LogicalResult
-bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
-                                     DestinationStyleOpInterface op,
-                                     const BufferizationOptions &options) {
+static LogicalResult bufferizeDestinationStyleOpInterface(
+    RewriterBase &rewriter, DestinationStyleOpInterface op,
+    const BufferizationOptions &options, BufferizationState &state) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(op);
@@ -49,7 +48,8 @@ bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
       newInputBuffers.push_back(opOperand->get());
       continue;
     }
-    FailureOr<Value> buffer = getBuffer(rewriter, opOperand->get(), options);
+    FailureOr<Value> buffer =
+        getBuffer(rewriter, opOperand->get(), options, state);
     if (failed(buffer))
       return failure();
     newInputBuffers.push_back(*buffer);
@@ -60,7 +60,7 @@ bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
   for (OpResult opResult : op->getOpResults()) {
     OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber());
     FailureOr<Value> resultBuffer =
-        getBuffer(rewriter, opOperand->get(), options);
+        getBuffer(rewriter, opOperand->get(), options, state);
     if (failed(resultBuffer))
       return failure();
     newOutputBuffers.push_back(*resultBuffer);
@@ -76,10 +76,10 @@ bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
   // new op. Since the new op does not have any tensor results, it does not
   // return anything.
   assert(op->getNumRegions() == 1 && "expected that op has 1 region");
-  OperationState state(op->getLoc(), op->getName(), newOperands, TypeRange{},
-                       op->getAttrs());
-  state.addRegion();
-  Operation *newOp = Operation::create(state);
+  OperationState opState(op->getLoc(), op->getName(), newOperands, TypeRange{},
+                         op->getAttrs());
+  opState.addRegion();
+  Operation *newOp = Operation::create(opState);
   newOp->getRegion(0).getBlocks().splice(newOp->getRegion(0).begin(),
                                          op->getRegion(0).getBlocks());
 
@@ -151,7 +151,7 @@ struct LinalgOpInterface
                           const BufferizationOptions &options,
                           BufferizationState &state) const {
     return bufferizeDestinationStyleOpInterface(
-        rewriter, cast<DestinationStyleOpInterface>(op), options);
+        rewriter, cast<DestinationStyleOpInterface>(op), options, state);
   }
 };
 
@@ -179,11 +179,11 @@ struct SoftmaxOpInterface
                           BufferizationState &state) const {
     auto softmaxOp = cast<linalg::SoftmaxOp>(op);
     FailureOr<Value> inputBuffer =
-        getBuffer(rewriter, softmaxOp.getInput(), options);
+        getBuffer(rewriter, softmaxOp.getInput(), options, state);
     if (failed(inputBuffer))
       return failure();
     FailureOr<Value> outputBuffer =
-        getBuffer(rewriter, softmaxOp.getOutput(), options);
+        getBuffer(rewriter, softmaxOp.getOutput(), options, state);
     if (failed(outputBuffer))
       return failure();
     rewriter.create<linalg::SoftmaxOp>(softmaxOp.getLoc(),
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
index a69bc9e5088ae..ff6af63eee531 100644
--- a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -138,7 +138,8 @@ struct GlobalStoreOpInterface
     auto targetMemref = rewriter.create<memref::GetGlobalOp>(
         loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
 
-    auto sourceMemref = getBuffer(rewriter, globalStoreOp.getValue(), options);
+    auto sourceMemref =
+        getBuffer(rewriter, globalStoreOp.getValue(), options, state);
     if (failed(sourceMemref)) {
       return failure();
     }
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 3ff1f5c49aece..59c240c62f934 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -104,11 +104,12 @@ struct ConditionOpInterface
     for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
       Value value = it.value();
       if (isa<TensorType>(value.getType())) {
-        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
+        FailureOr<Value> maybeBuffer =
+            getBuffer(rewriter, value, options, state);
         if (failed(maybeBuffer))
           return failure();
         FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
-            whileOp.getAfterArguments()[it.index()], options);
+            whileOp.getAfterArguments()[it.index()], options, state);
         if (failed(resultType))
           return failure();
         Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
@@ -196,7 +197,7 @@ struct ExecuteRegionOpInterface
     // Bufferize every block.
     for (Block &block : newOp.getRegion())
       if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
-                                                        options)))
+                                                        options, state)))
         return failure();
 
     // Update all uses of the old op.
@@ -251,7 +252,7 @@ struct IfOpInterface
         newTypes.push_back(result.getType());
         continue;
       }
-      auto bufferType = bufferization::getBufferType(result, options);
+      auto bufferType = bufferization::getBufferType(result, options, state);
       if (failed(bufferType))
         return failure();
       newTypes.push_back(*bufferType);
@@ -275,6 +276,7 @@ struct IfOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto ifOp = cast<scf::IfOp>(op);
     auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
@@ -290,8 +292,8 @@ struct IfOpInterface
       // True branch was already bufferized.
       thenBufferType = cast<BaseMemRefType>(thenValue.getType());
     } else {
-      auto maybeBufferType =
-          bufferization::getBufferType(thenValue, options, invocationStack);
+      auto maybeBufferType = bufferization::getBufferType(
+          thenValue, options, state, invocationStack);
       if (failed(maybeBufferType))
         return failure();
       thenBufferType = *maybeBufferType;
@@ -300,8 +302,8 @@ struct IfOpInterface
       // False branch was already bufferized.
       elseBufferType = cast<BaseMemRefType>(elseValue.getType());
     } else {
-      auto maybeBufferType =
-          bufferization::getBufferType(elseValue, options, invocationStack);
+      auto maybeBufferType = bufferization::getBufferType(
+          elseValue, options, state, invocationStack);
       if (failed(maybeBufferType))
         return failure();
       elseBufferType = *maybeBufferType;
@@ -362,7 +364,7 @@ struct IndexSwitchOpInterface
         newTypes.push_back(result.getType());
         continue;
       }
-      auto bufferType = bufferization::getBufferType(result, options);
+      auto bufferType = bufferization::getBufferType(result, options, state);
       if (failed(bufferType))
         return failure();
       newTypes.push_back(*bufferType);
@@ -390,6 +392,7 @@ struct IndexSwitchOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto switchOp = cast<scf::IndexSwitchOp>(op);
     assert(value.getDefiningOp() == op && "invalid value");
@@ -401,8 +404,8 @@ struct IndexSwitchOpInterface
       Value yieldedValue = yieldOp->getOperand(resultNum);
       if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
         return bufferType;
-      auto maybeBufferType =
-          bufferization::getBufferType(yieldedValue, options, invocationStack);
+      auto maybeBufferType = bufferization::getBufferType(
+          yieldedValue, options, state, invocationStack);
       if (failed(maybeBufferType))
         return failure();
       return maybeBufferType;
@@ -468,12 +471,12 @@ DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
 /// given OpOperands. If an operand is not a tensor, return the original value.
 static FailureOr<SmallVector<Value>>
 getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
-           const BufferizationOptions &options) {
+           const BufferizationOptions &options, BufferizationState &state) {
   SmallVector<Value> result;
   for (OpOperand &opOperand : operands) {
     if (isa<TensorType>(opOperand.get().getType())) {
       FailureOr<Value> resultBuffer =
-          getBuffer(rewriter, opOperand.get(), options);
+          getBuffer(rewriter, opOperand.get(), options, state);
       if (failed(resultBuffer))
         return failure();
       result.push_back(*resultBuffer);
@@ -521,10 +524,11 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
 /// layout map and a cast must be inserted.
 static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
     Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
-    const BufferizationOptions &options, SmallVector<Value> &invocationStack) {
+    const BufferizationOptions &options, BufferizationState &state,
+    SmallVector<Value> &invocationStack) {
   // Determine the buffer type of the init_arg.
   auto initArgBufferType =
-      bufferization::getBufferType(initArg, options, invocationStack);
+      bufferization::getBufferType(initArg, options, state, invocationStack);
   if (failed(initArgBufferType))
     return failure();
 
@@ -550,8 +554,8 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
   } else {
     // Note: This typically triggers a recursive call for the buffer type of
     // the iter_arg.
-    auto maybeBufferType =
-        bufferization::getBufferType(yieldedValue, options, invocationStack);
+    auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
+                                                        state, invocationStack);
     if (failed(maybeBufferType))
       return failure();
     yieldedValueBufferType = *maybeBufferType;
@@ -650,12 +654,14 @@ struct ForOpInterface
   }
 
   LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
-                                 const AnalysisState &state) const {
+                                 const AnalysisState &analysisState,
+                                 BufferizationState &bufferizationState) const {
     auto bufferizableOp = cast<BufferizableOpInterface>(op);
-    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
+    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
+            rewriter, analysisState, bufferizationState)))
       return failure();
 
-    if (state.getOptions().copyBeforeWrite)
+    if (analysisState.getOptions().copyBeforeWrite)
       return success();
 
     // According to the `getAliasing...` implementations, a bufferized OpResult
@@ -683,12 +689,13 @@ struct ForOpInterface
           doesNotAliasExternalValue(
               it.value(), &forOp.getRegion(),
               /*exceptions=*/forOp.getRegionIterArg(it.index()),
-              static_cast<const OneShotAnalysisState &>(state))) {
+              static_cast<const OneShotAnalysisState &>(analysisState))) {
         yieldValues.push_back(it.value());
         continue;
       }
       FailureOr<Value> alloc = allocateTensorForShapedValue(
-          rewriter, yieldOp.getLoc(), it.value(), state.getOptions());
+          rewriter, yieldOp.getLoc(), it.value(), analysisState.getOptions(),
+          bufferizationState);
       if (failed(alloc))
         return failure();
       yieldValues.push_back(*alloc);
@@ -701,6 +708,7 @@ struct ForOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto forOp = cast<scf::ForOp>(op);
     assert(getOwnerOfValue(value) == op && "invalid value");
@@ -709,7 +717,8 @@ struct ForOpInterface
     if (auto opResult = dyn_cast<OpResult>(value)) {
       // The type of an OpResult must match the corresponding iter_arg type.
       BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
-      return bufferization::getBufferType(bbArg, options, invocationStack);
+      return bufferization::getBufferType(bbArg, options, state,
+                                          invocationStack);
     }
 
     // Compute result/argument number.
@@ -722,7 +731,7 @@ struct ForOpInterface
     BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
     Value initArg = forOp.getInitArgs()[resultNum];
     return computeLoopRegionIterArgBufferType(
-        op, iterArg, initArg, yieldedValue, options, invocationStack);
+        op, iterArg, initArg, yieldedValue, options, state, invocationStack);
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -737,7 +746,7 @@ struct ForOpInterface
 
     // The new memref init_args of the loop.
     FailureOr<SmallVector<Value>> maybeInitArgs =
-        getBuffers(rewriter, forOp.getInitArgsMutable(), options);
+        getBuffers(rewriter, forOp.getInitArgsMutable(), options, state);
     if (failed(maybeInitArgs))
       return failure();
     SmallVector<Value> initArgs = *maybeInitArgs;
@@ -752,7 +761,7 @@ struct ForOpInterface
         castedInitArgs.push_back(initArg);
         continue;
       }
-      auto targetType = bufferization::getBufferType(result, options);
+      auto targetType = bufferization::getBufferType(result, options, state);
       if (failed(targetType))
         return failure();
       castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
@@ -892,12 +901,14 @@ struct WhileOpInterface
   }
 
   LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
-                                 const AnalysisState &state) const {
+                                 const AnalysisState &analysisState,
+                                 BufferizationState &bufferizationState) const {
     auto bufferizableOp = cast<BufferizableOpInterface>(op);
-    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
+    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
+            rewriter, analysisState, bufferizationState)))
       return failure();
 
-    if (state.getOptions().copyBeforeWrite)
+    if (analysisState.getOptions().copyBeforeWrite)
       return success();
 
     // According to the `getAliasing...` implementations, a bufferized OpResult
@@ -914,9 +925,10 @@ struct WhileOpInterface
     // For every yielded value, is the value equivalent to its corresponding
     // bbArg?
     DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
-        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
-    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
-        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);
+        whileOp.getBeforeArguments(), conditionOp.getArgs(), analysisState);
+    DenseSet<int64_t> equivalentYieldsAfter =
+        getEquivalentBuffers(whileOp.getAfterArguments(),
+                             whileOp.getYieldOp().getResults(), analysisState);
 
     // Update "before" region.
     rewriter.setInsertionPoint(conditionOp);
@@ -931,7 +943,8 @@ struct WhileOpInterface
         continue;
       }
       FailureOr<Value> alloc = allocateTensorForShapedValue(
-          rewriter, conditionOp.getLoc(), value, state.getOptions());
+          rewriter, conditionOp.getLoc(), value, analysisState.getOptions(),
+          bufferizationState);
       if (failed(alloc))
         return failure();
       beforeYieldValues.push_back(*alloc);
@@ -956,7 +969,7 @@ struct WhileOpInterface
 
     // The new memref init_args of the loop.
     FailureOr<SmallVector<Value>> maybeInitArgs =
-        getBuffers(rewriter, whileOp.getInitsMutable(), options);
+        getBuffers(rewriter, whileOp.getInitsMutable(), options, state);
     if (failed(maybeInitArgs))
       return failure();
     SmallVector<Value> initArgs = *maybeInitArgs;
@@ -971,7 +984,7 @@ struct WhileOpInterface
         castedInitArgs.push_back(initArg);
         continue;
       }
-      auto targetType = bufferization::getBufferType(beforeArg, options);
+      auto targetType = bufferization::getBufferType(beforeArg, options, state);
       if (failed(targetType))
         return failure();
       castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
@@ -984,7 +997,7 @@ struct WhileOpInterface
             return bbArg.getType();
           // TODO: error handling
           return llvm::cast<Type>(
-              *bufferization::getBufferType(bbArg, options));
+              *bufferization::getBufferType(bbArg, options, state));
         }));
 
     // Construct a new scf.while op with memref instead of tensor values.
@@ -1029,6 +1042,7 @@ struct WhileOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto whileOp = cast<scf::WhileOp>(op);
     assert(getOwnerOfValue(value) == op && "invalid value");
@@ -1041,7 +1055,7 @@ struct WhileOpInterface
         auto yieldOp = whileOp.getYieldOp();
         Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
         return computeLoopRegionIterArgBufferType(
-            op, bbArg, initArg, yieldedValue, options, invocationStack);
+            op, bbArg, initArg, yieldedValue, options, state, invocationStack);
       }
     }
 
@@ -1062,7 +1076,7 @@ struct WhileOpInterface
       // scf.condition was already bufferized.
       return cast<BaseMemRefType>(conditionYieldedVal.getType());
     }
-    return bufferization::getBufferType(conditionYieldedVal, options,
+    return bufferization::getBufferType(conditionYieldedVal, options, state,
                                         invocationStack);
   }
 
@@ -1161,7 +1175,8 @@ struct YieldOpInterface
     for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
       Value value = it.value();
       if (isa<TensorType>(value.getType())) {
-        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
+        FailureOr<Value> maybeBuffer =
+            getBuffer(rewriter, value, options, state);
         if (failed(maybeBuffer))
           return failure();
         Value buffer = *maybeBuffer;
@@ -1169,14 +1184,14 @@ struct YieldOpInterface
         if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
                 yieldOp->getParentOp())) {
           FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
-              yieldOp->getParentOp()->getResult(it.index()), options);
+              yieldOp->getParentOp()->getResult(it.index()), options, state);
           if (failed(resultType))
             return failure();
           buffer = castBuffer(rewriter, buffer, *resultType);
         } else if (auto whileOp =
                        dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
           FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
-              whileOp.getBeforeArguments()[it.index()], options);
+              whileOp.getBeforeArguments()[it.index()], options, state);
           if (failed(resultType))
             return failure();
           buffer = castBuffer(rewriter, buffer, *resultType);
@@ -1236,7 +1251,7 @@ struct ForallOpInterface
     // Get buffers for all output operands.
     SmallVector<Value> buffers;
     for (Value out : forallOp.getOutputs()) {
-      FailureOr<Value> buffer = getBuffer(rewriter, out, options);
+      FailureOr<Value> buffer = getBuffer(rewriter, out, options, state);
       if (failed(buffer))
         return failure();
       buffers.push_back(*buffer);
@@ -1283,6 +1298,7 @@ struct ForallOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto forallOp = cast<ForallOp>(op);
 
@@ -1290,13 +1306,14 @@ struct ForallOpInterface
       // A tensor block argument has the same bufferized type as the
       // corresponding output operand.
       return bufferization::getBufferType(
-          forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack);
+          forallOp.getTiedOpOperand(bbArg)->get(), options, state,
+          invocationStack);
 
     // The bufferized result type is the same as the bufferized type of the
     // corresponding output operand.
     return bufferization::getBufferType(
         forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
-        invocationStack);
+        state, invocationStack);
   }
 
   bool isRepetitiveRegion(Operation *op, unsigned index) const {
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index e8cab76d3c753..dc91117a51936 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -119,7 +119,7 @@ struct AssumingYieldOpInterface
     SmallVector<Value> newResults;
     for (Value value : yieldOp.getOperands()) {
       if (isa<TensorType>(value.getType())) {
-        FailureOr<Value> buffer = getBuffer(rewriter, value, options);
+        FailureOr<Value> buffer = getBuffer(rewriter, value, options, state);
         if (failed(buffer))
           return failure();
         newResults.push_back(*buffer);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 7c7c64f2aef01..a3ab53d818115 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -152,8 +152,10 @@ class SparsificationAndBufferizationPass
     // invalidate the results of the analysis. From now on, only small and
     // localized rewrites are allowed, such as replacing a tensor op with its
     // memref equivalent.
-    if (failed(bufferization::insertTensorCopies(getOperation(),
-                                                 bufferizationOptions)))
+    bufferization::BufferizationState bufferizationState;
+
+    if (failed(bufferization::insertTensorCopies(
+            getOperation(), bufferizationOptions, bufferizationState)))
       return signalPassFailure();
 
     // Option `testAnalysisOnly` is a debug/testing flag. If set, the results of
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 630e970cd4b19..154f12b31fc70 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -51,10 +51,11 @@ struct CastOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto castOp = cast<tensor::CastOp>(op);
     auto maybeSrcBufferType = bufferization::getBufferType(
-        castOp.getSource(), options, invocationStack);
+        castOp.getSource(), options, state, invocationStack);
     if (failed(maybeSrcBufferType))
       return failure();
     Attribute memorySpace = maybeSrcBufferType->getMemorySpace();
@@ -89,13 +90,13 @@ struct CastOpInterface
 
     // The result buffer still has the old (pre-cast) type.
     FailureOr<Value> resultBuffer =
-        getBuffer(rewriter, castOp.getSource(), options);
+        getBuffer(rewriter, castOp.getSource(), options, state);
     if (failed(resultBuffer))
       return failure();
 
     // Compute the new type.
     auto resultMemRefType =
-        bufferization::getBufferType(castOp.getResult(), options);
+        bufferization::getBufferType(castOp.getResult(), options, state);
     if (failed(resultMemRefType))
       return failure();
     if (resultBuffer->getType() == *resultMemRefType) {
@@ -141,10 +142,11 @@ struct CollapseShapeOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
     auto maybeSrcBufferType = bufferization::getBufferType(
-        collapseShapeOp.getSrc(), options, invocationStack);
+        collapseShapeOp.getSrc(), options, state, invocationStack);
     if (failed(maybeSrcBufferType))
       return failure();
     auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
@@ -168,7 +170,7 @@ struct CollapseShapeOpInterface
     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
     RankedTensorType tensorResultType = collapseShapeOp.getResultType();
     FailureOr<Value> maybeBuffer =
-        getBuffer(rewriter, collapseShapeOp.getSrc(), options);
+        getBuffer(rewriter, collapseShapeOp.getSrc(), options, state);
     if (failed(maybeBuffer))
       return failure();
     Value buffer = *maybeBuffer;
@@ -210,7 +212,7 @@ struct CollapseShapeOpInterface
       // TODO: Create alloc_tensor ops during TensorCopyInsertion.
       AnalysisState analysisState(options);
       FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
-          rewriter, op->getLoc(), collapseShapeOp.getSrc(), options);
+          rewriter, op->getLoc(), collapseShapeOp.getSrc(), options, state);
       if (failed(tensorAlloc))
         return failure();
       auto memrefType =
@@ -252,7 +254,7 @@ struct DimOpInterface
                           const BufferizationOptions &options,
                           BufferizationState &state) const {
     auto dimOp = cast<tensor::DimOp>(op);
-    FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
+    FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options, state);
     if (failed(v))
       return failure();
     replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
@@ -286,7 +288,8 @@ struct EmptyOpInterface
 
     // Allocate a tensor. This emits a "bufferization.alloc_tensor" op.
     FailureOr<Value> allocTensor = allocateTensorForShapedValue(
-        rewriter, op->getLoc(), emptyOp.getResult(), options, /*copy=*/false);
+        rewriter, op->getLoc(), emptyOp.getResult(), options, state,
+        /*copy=*/false);
     if (failed(allocTensor))
       return failure();
     rewriter.replaceOp(op, *allocTensor);
@@ -317,10 +320,11 @@ struct ExpandShapeOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
     auto maybeSrcBufferType = bufferization::getBufferType(
-        expandShapeOp.getSrc(), options, invocationStack);
+        expandShapeOp.getSrc(), options, state, invocationStack);
     if (failed(maybeSrcBufferType))
       return failure();
     auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
@@ -338,7 +342,7 @@ struct ExpandShapeOpInterface
     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
     auto tensorResultType = expandShapeOp.getResultType();
     FailureOr<Value> buffer =
-        getBuffer(rewriter, expandShapeOp.getSrc(), options);
+        getBuffer(rewriter, expandShapeOp.getSrc(), options, state);
     if (failed(buffer))
       return failure();
 
@@ -382,13 +386,13 @@ struct ExtractSliceOpInterface
 
     // Get source buffer.
     FailureOr<Value> srcMemref =
-        getBuffer(rewriter, extractSliceOp.getSource(), options);
+        getBuffer(rewriter, extractSliceOp.getSource(), options, state);
     if (failed(srcMemref))
       return failure();
 
     // Take a subview of the source buffer.
-    auto resultMemrefType =
-        bufferization::getBufferType(extractSliceOp.getResult(), options);
+    auto resultMemrefType = bufferization::getBufferType(
+        extractSliceOp.getResult(), options, state);
     if (failed(resultMemrefType))
       return failure();
     Value subView = rewriter.create<memref::SubViewOp>(
@@ -401,11 +405,12 @@ struct ExtractSliceOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
     assert(value == extractSliceOp.getResult() && "invalid value");
     auto srcMemrefType = bufferization::getBufferType(
-        extractSliceOp.getSource(), options, invocationStack);
+        extractSliceOp.getSource(), options, state, invocationStack);
     if (failed(srcMemrefType))
       return failure();
     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
@@ -442,7 +447,7 @@ struct ExtractOpInterface
                           BufferizationState &state) const {
     auto extractOp = cast<tensor::ExtractOp>(op);
     FailureOr<Value> srcMemref =
-        getBuffer(rewriter, extractOp.getTensor(), options);
+        getBuffer(rewriter, extractOp.getTensor(), options, state);
     if (failed(srcMemref))
       return failure();
     replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
@@ -491,12 +496,12 @@ struct FromElementsOpInterface
     auto shape = tensorType.getShape();
     // TODO: Create alloc_tensor ops during TensorCopyInsertion.
     FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
-        rewriter, loc, fromElementsOp.getResult(), options,
+        rewriter, loc, fromElementsOp.getResult(), options, state,
         /*copy=*/false);
     if (failed(tensorAlloc))
       return failure();
     FailureOr<BaseMemRefType> memrefType =
-        bufferization::getBufferType(*tensorAlloc, options);
+        bufferization::getBufferType(*tensorAlloc, options, state);
     if (failed(memrefType))
       return failure();
     Value buffer = rewriter.create<bufferization::ToBufferOp>(
@@ -607,7 +612,7 @@ struct GenerateOpInterface
     // Allocate memory.
     Location loc = op->getLoc();
     FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
-        rewriter, loc, generateOp.getResult(), options,
+        rewriter, loc, generateOp.getResult(), options, state,
         /*copy=*/false);
     if (failed(tensorAlloc))
       return failure();
@@ -633,7 +638,7 @@ struct InsertOpInterface
                           BufferizationState &state) const {
     auto insertOp = cast<tensor::InsertOp>(op);
     FailureOr<Value> destMemref =
-        getBuffer(rewriter, insertOp.getDest(), options);
+        getBuffer(rewriter, insertOp.getDest(), options, state);
     if (failed(destMemref))
       return failure();
     rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
@@ -695,7 +700,7 @@ struct InsertSliceOpInterface
 
     // Get destination buffer.
     FailureOr<Value> dstMemref =
-        getBuffer(rewriter, insertSliceOp.getDest(), options);
+        getBuffer(rewriter, insertSliceOp.getDest(), options, state);
     if (failed(dstMemref))
       return failure();
 
@@ -712,7 +717,7 @@ struct InsertSliceOpInterface
     // Copy tensor. If this tensor.insert_slice has a matching
     // tensor.extract_slice, the copy operation will eventually fold away.
     FailureOr<Value> srcMemref =
-        getBuffer(rewriter, insertSliceOp.getSource(), options);
+        getBuffer(rewriter, insertSliceOp.getSource(), options, state);
     if (failed(srcMemref))
       return failure();
     if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
@@ -749,11 +754,12 @@ struct PadOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     // Infer memory space from the source tensor.
     auto padOp = cast<tensor::PadOp>(op);
     auto maybeSrcBufferType = bufferization::getBufferType(
-        padOp.getSource(), options, invocationStack);
+        padOp.getSource(), options, state, invocationStack);
     if (failed(maybeSrcBufferType))
       return failure();
     MemRefLayoutAttrInterface layout;
@@ -797,9 +803,9 @@ struct PadOpInterface
     }
 
     // Allocate a buffer for the padded result.
-    FailureOr<Value> tensorAlloc =
-        allocateTensorForShapedValue(rewriter, loc, padOp.getResult(), options,
-                                     /*copy=*/false);
+    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
+        rewriter, loc, padOp.getResult(), options, state,
+        /*copy=*/false);
     if (failed(tensorAlloc))
       return failure();
 
@@ -846,7 +852,8 @@ struct RankOpInterface
                           const BufferizationOptions &options,
                           BufferizationState &state) const {
     auto rankOp = cast<tensor::RankOp>(op);
-    FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
+    FailureOr<Value> v =
+        getBuffer(rewriter, rankOp.getTensor(), options, state);
     if (failed(v))
       return failure();
     replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
@@ -885,13 +892,13 @@ struct ReshapeOpInterface
                           BufferizationState &state) const {
     auto reshapeOp = cast<tensor::ReshapeOp>(op);
     FailureOr<Value> srcBuffer =
-        getBuffer(rewriter, reshapeOp.getSource(), options);
+        getBuffer(rewriter, reshapeOp.getSource(), options, state);
     FailureOr<Value> shapeBuffer =
-        getBuffer(rewriter, reshapeOp.getShape(), options);
+        getBuffer(rewriter, reshapeOp.getShape(), options, state);
     if (failed(srcBuffer) || failed(shapeBuffer))
       return failure();
     auto maybeResultMemRefType =
-        bufferization::getBufferType(reshapeOp.getResult(), options);
+        bufferization::getBufferType(reshapeOp.getResult(), options, state);
     if (failed(maybeResultMemRefType))
       return failure();
 
@@ -901,7 +908,7 @@ struct ReshapeOpInterface
     auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
     if (srcType && !srcType.getLayout().isIdentity()) {
       FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
-          rewriter, op->getLoc(), reshapeOp.getSource(), options);
+          rewriter, op->getLoc(), reshapeOp.getSource(), options, state);
       if (failed(tensorAlloc))
         return failure();
       auto memrefType = MemRefType::get(
@@ -920,11 +927,12 @@ struct ReshapeOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto reshapeOp = cast<tensor::ReshapeOp>(op);
     assert(value == reshapeOp.getResult() && "unexpected value provided");
     auto maybeSourceBufferType = bufferization::getBufferType(
-        reshapeOp.getSource(), options, invocationStack);
+        reshapeOp.getSource(), options, state, invocationStack);
     if (failed(maybeSourceBufferType))
       return failure();
     return getMemRefTypeWithStaticIdentityLayout(
@@ -966,11 +974,11 @@ struct ParallelInsertSliceOpInterface
 
     // Get source and destination buffers.
     FailureOr<Value> destBuffer =
-        getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
+        getBuffer(rewriter, parallelInsertSliceOp.getDest(), options, state);
     if (failed(destBuffer))
       return failure();
     FailureOr<Value> srcBuffer =
-        getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
+        getBuffer(rewriter, parallelInsertSliceOp.getSource(), options, state);
     if (failed(srcBuffer))
       return failure();
 
@@ -1016,7 +1024,8 @@ struct ParallelInsertSliceOpInterface
   /// tensor.parallel_insert_slice op has implicit inplace behavior. We
   /// shouldn't create copy to resolve conflict.
   LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
-                                 const AnalysisState &state) const {
+                                 const AnalysisState &analysisState,
+                                 BufferizationState &bufferizationState) const {
     return success();
   }
 };
@@ -1038,7 +1047,7 @@ struct SplatOpInterface
     // Allocate memory.
     Location loc = op->getLoc();
     FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
-        rewriter, loc, splatOp.getResult(), options,
+        rewriter, loc, splatOp.getResult(), options, state,
         /*copy=*/false);
     if (failed(tensorAlloc))
       return failure();
@@ -1097,7 +1106,7 @@ struct ConcatOpInterface
     // Allocate memory.
     Location loc = op->getLoc();
     FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
-        rewriter, loc, concatOp.getResult(), options,
+        rewriter, loc, concatOp.getResult(), options, state,
         /*copy=*/false);
     if (failed(tensorAlloc))
       return failure();
@@ -1147,7 +1156,7 @@ struct ConcatOpInterface
 
     for (auto operand : concatOp.getInputs()) {
       // Get the buffer for the operand.
-      FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options);
+      FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options, state);
       if (failed(srcBuffer))
         return failure();
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 45b6e7c512947..a94a1d3567573 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -53,7 +53,8 @@ struct TransferReadOpInterface
     auto readOp = cast<vector::TransferReadOp>(op);
     assert(isa<TensorType>(readOp.getShapedType()) &&
            "only tensor types expected");
-    FailureOr<Value> buffer = getBuffer(rewriter, readOp.getBase(), options);
+    FailureOr<Value> buffer =
+        getBuffer(rewriter, readOp.getBase(), options, state);
     if (failed(buffer))
       return failure();
     replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
@@ -112,7 +113,7 @@ struct TransferWriteOpInterface
 
     // Create a new transfer_write on buffer that doesn't have a return value.
     FailureOr<Value> resultBuffer =
-        getBuffer(rewriter, writeOp.getBase(), options);
+        getBuffer(rewriter, writeOp.getBase(), options, state);
     if (failed(resultBuffer))
       return failure();
     rewriter.create<vector::TransferWriteOp>(
@@ -155,7 +156,8 @@ struct GatherOpInterface
     auto gatherOp = cast<vector::GatherOp>(op);
     assert(isa<TensorType>(gatherOp.getBaseType()) &&
            "only tensor types expected");
-    FailureOr<Value> buffer = getBuffer(rewriter, gatherOp.getBase(), options);
+    FailureOr<Value> buffer =
+        getBuffer(rewriter, gatherOp.getBase(), options, state);
     if (failed(buffer))
       return failure();
     replaceOpWithNewBufferizedOp<vector::GatherOp>(
@@ -185,9 +187,11 @@ struct MaskOpInterface
   }
 
   LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
-                                 const AnalysisState &state) const {
+                                 const AnalysisState &analysisState,
+                                 BufferizationState &bufferizationState) const {
     auto bufferizableOp = cast<BufferizableOpInterface>(op);
-    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
+    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
+            rewriter, analysisState, bufferizationState)))
       return failure();
 
     // TODO: Remove this function when vector.mask bodies can bufferize
@@ -302,7 +306,8 @@ struct YieldOpInterface
     SmallVector<Value> newResults;
     for (Value value : yieldOp.getOperands()) {
       if (isa<TensorType>(value.getType())) {
-        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
+        FailureOr<Value> maybeBuffer =
+            getBuffer(rewriter, value, options, state);
         if (failed(maybeBuffer))
           return failure();
         newResults.push_back(*maybeBuffer);
diff --git a/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp b/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp
index 2991a3c165ee2..dfaebccde7dcc 100644
--- a/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp
+++ b/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp
@@ -48,7 +48,11 @@ struct TestTensorCopyInsertionPass
       options.defaultMemorySpaceFn =
           [](TensorType t) -> std::optional<Attribute> { return std::nullopt; };
     }
-    if (failed(bufferization::insertTensorCopies(getOperation(), options)))
+
+    bufferization::BufferizationState bufferizationState;
+
+    if (failed(bufferization::insertTensorCopies(getOperation(), options,
+                                                 bufferizationState)))
       signalPassFailure();
   }
 

>From fb6838eaef09776eddf971ae1cc8c796c01c96de Mon Sep 17 00:00:00 2001
From: Michele Scuttari <michele.scuttari at outlook.com>
Date: Mon, 26 May 2025 15:46:37 +0200
Subject: [PATCH 2/3] Fix code format

---
 .../Dialect/Bufferization/IR/UnstructuredControlFlow.h    | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index 7c07f705c8435..00f7799fc18c6 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -34,7 +34,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-    BufferizationState &state,
+                BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     // Note: The user may want to override this function for OpResults in
     // case the bufferized result type is different from the bufferized type of
@@ -82,9 +82,9 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
       if (bufferType == callerType)
         continue;
 
-        // If the computed buffer type does not match the computed buffer type
-        // of the earlier forwarded operands, fall back to a buffer type with a
-        // fully dynamic layout map.
+      // If the computed buffer type does not match the computed buffer type
+      // of the earlier forwarded operands, fall back to a buffer type with a
+      // fully dynamic layout map.
 #ifndef NDEBUG
       if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
         assert(bufferType.hasRank() && callerType.hasRank() &&

>From 3a59029fe4eff85f5d6a601123d6cb27dd9f3065 Mon Sep 17 00:00:00 2001
From: Michele Scuttari <michele.scuttari at outlook.com>
Date: Tue, 27 May 2025 11:27:36 +0200
Subject: [PATCH 3/3] Make references to BufferizationState constant

---
 .../IR/BufferizableOpInterface.h              | 10 +++----
 .../IR/BufferizableOpInterface.td             |  6 ++---
 .../Bufferization/IR/BufferizationOps.td      |  4 +--
 .../IR/UnstructuredControlFlow.h              |  2 +-
 .../Bufferization/Transforms/Transforms.h     |  4 +--
 .../BufferizableOpInterfaceImpl.cpp           |  2 +-
 .../IR/BufferizableOpInterface.cpp            | 13 +++++-----
 .../Bufferization/IR/BufferizationOps.cpp     |  2 +-
 .../FuncBufferizableOpInterfaceImpl.cpp       |  4 +--
 .../Transforms/TensorCopyInsertion.cpp        |  4 +--
 .../BufferizableOpInterfaceImpl.cpp           |  2 +-
 .../BufferizableOpInterfaceImpl.cpp           | 26 ++++++++++---------
 .../BufferizableOpInterfaceImpl.cpp           | 19 +++++++-------
 .../BufferizableOpInterfaceImpl.cpp           |  7 ++---
 14 files changed, 55 insertions(+), 50 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 328d928c9ebdb..adccbef754ec5 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -598,14 +598,14 @@ class BufferizationState {
 FailureOr<Value>
 allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue,
                              const BufferizationOptions &options,
-                             BufferizationState &state, bool copy = true);
+                             const BufferizationState &state, bool copy = true);
 
 /// Lookup the buffer for the given value. If the value was not bufferized
 /// yet, wrap it in a ToBufferOp. Otherwise, it is the result of a ToTensorOp,
 /// from which the memref operand is returned.
 FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
                            const BufferizationOptions &options,
-                           BufferizationState &state);
+                           const BufferizationState &state);
 
 /// Return the buffer type for a given Value (tensor) after bufferization
 /// without bufferizing any IR.
@@ -617,7 +617,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
 /// This function is a wrapper around BufferizableOpInterface::getBufferType.
 FailureOr<BaseMemRefType> getBufferType(Value value,
                                         const BufferizationOptions &options,
-                                        BufferizationState &state);
+                                        const BufferizationState &state);
 
 /// Return the buffer type for a given Value (tensor) after bufferization
 /// without bufferizing any IR. This function (and not the other overload
@@ -631,7 +631,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
 /// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
 FailureOr<BaseMemRefType> getBufferType(Value value,
                                         const BufferizationOptions &options,
-                                        BufferizationState &state,
+                                        const BufferizationState &state,
                                         SmallVector<Value> &invocationStack);
 
 /// Return "true" if the given op has tensor semantics and should be bufferized.
@@ -712,7 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
 /// places.
 FailureOr<BaseMemRefType>
 defaultGetBufferType(Value value, const BufferizationOptions &options,
-                     BufferizationState &state,
+                     const BufferizationState &state,
                      SmallVector<Value> &invocationStack);
 
 /// This is the default implementation of
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 80f9b72531660..5607df8c96039 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -382,7 +382,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*methodName=*/"resolveConflicts",
         /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
                       "const ::mlir::bufferization::AnalysisState &":$analysisState,
-                      "::mlir::bufferization::BufferizationState &":$bufferizationState),
+                      "const ::mlir::bufferization::BufferizationState &":$bufferizationState),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           auto bufferizableOp =
@@ -524,7 +524,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*methodName=*/"getBufferType",
         /*args=*/(ins "::mlir::Value":$value,
                       "const ::mlir::bufferization::BufferizationOptions &":$options,
-                      "::mlir::bufferization::BufferizationState &":$state,
+                      "const ::mlir::bufferization::BufferizationState &":$state,
                       "::llvm::SmallVector<::mlir::Value> &":$invocationStack),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
@@ -619,7 +619,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
     ::llvm::LogicalResult resolveTensorOpOperandConflicts(
         ::mlir::RewriterBase &rewriter,
         const ::mlir::bufferization::AnalysisState &analysisState,
-        ::mlir::bufferization::BufferizationState &bufferizationState);
+        const ::mlir::bufferization::BufferizationState &bufferizationState);
 
     /// Return `true` if the given OpOperand creates an alias but does neither
     /// read nor write. This implies that `bufferizesToMemoryRead` and
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 0ee4f79144158..3d4dcdee2663b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -112,7 +112,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
 
     FailureOr<BaseMemRefType> getBufferType(
         Value value, const BufferizationOptions &options,
-        BufferizationState &state,
+        const BufferizationState &state,
         SmallVector<Value> &invocationStack);
 
     RankedTensorType getType() {
@@ -472,7 +472,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
     FailureOr<BaseMemRefType> getBufferType(
         Value value, const BufferizationOptions &options,
-        BufferizationState &state, SmallVector<Value> &invocationStack) {
+        const BufferizationState &state, SmallVector<Value> &invocationStack) {
       return ::llvm::cast<BaseMemRefType>(getMemref().getType());
     }
   }];
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index 00f7799fc18c6..a441b8b66659e 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -34,7 +34,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                BufferizationState &state,
+                const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     // Note: The user may want to override this function for OpResults in
     // case the bufferized result type is different from the bufferized type of
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
index e587753ddebee..e17d5264a1a45 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -75,7 +75,7 @@ void hoistBuffersFromLoops(Operation *op);
 /// additional buffer allocations.
 LogicalResult insertTensorCopies(Operation *op,
                                  const OneShotBufferizationOptions &options,
-                                 BufferizationState &bufferizationState,
+                                 const BufferizationState &bufferizationState,
                                  BufferizationStatistics *statistics = nullptr);
 
 /// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
@@ -83,7 +83,7 @@ LogicalResult insertTensorCopies(Operation *op,
 /// additional buffer allocations.
 LogicalResult insertTensorCopies(Operation *op,
                                  const AnalysisState &analysisState,
-                                 BufferizationState &bufferizationState);
+                                 const BufferizationState &bufferizationState);
 
 /// Populate patterns to lower tensor.empty ops to bufferization.alloc_tensor
 /// ops.
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 0389a984e169c..a57d58ab28d28 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -183,7 +183,7 @@ struct SelectOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                BufferizationState &state,
+                const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto selectOp = cast<arith::SelectOp>(op);
     assert(value == selectOp.getResult() && "invalid value");
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 7d67d4a33ac32..1d6e1bdaf80f5 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -165,7 +165,8 @@ Operation *bufferization::getOwnerOfValue(Value value) {
 /// allocated.
 FailureOr<Value> bufferization::allocateTensorForShapedValue(
     OpBuilder &b, Location loc, Value shapedValue,
-    const BufferizationOptions &options, BufferizationState &state, bool copy) {
+    const BufferizationOptions &options, const BufferizationState &state,
+    bool copy) {
   Value tensor;
   if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
     tensor = shapedValue;
@@ -224,7 +225,7 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
 
 LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
     RewriterBase &rewriter, const AnalysisState &analysisState,
-    BufferizationState &bufferizationState) {
+    const BufferizationState &bufferizationState) {
   OpBuilder::InsertionGuard g(rewriter);
   Operation *op = getOperation();
   SmallVector<OpOperand *> outOfPlaceOpOperands;
@@ -670,7 +671,7 @@ static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
 
 FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                           const BufferizationOptions &options,
-                                          BufferizationState &state) {
+                                          const BufferizationState &state) {
 #ifndef NDEBUG
   auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
   assert(tensorType && "unexpected non-tensor type");
@@ -695,7 +696,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
 /// Return the buffer type for a given Value (tensor) after bufferization.
 FailureOr<BaseMemRefType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options,
-                             BufferizationState &state) {
+                             const BufferizationState &state) {
   SmallVector<Value> invocationStack;
   return getBufferType(value, options, state, invocationStack);
 }
@@ -703,7 +704,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
 /// Return the buffer type for a given Value (tensor) after bufferization.
 FailureOr<BaseMemRefType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options,
-                             BufferizationState &state,
+                             const BufferizationState &state,
                              SmallVector<Value> &invocationStack) {
   assert(llvm::isa<TensorType>(value.getType()) &&
          "unexpected non-tensor type");
@@ -951,7 +952,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
 
 FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
     Value value, const BufferizationOptions &options,
-    BufferizationState &bufferizationState,
+    const BufferizationState &bufferizationState,
     SmallVector<Value> &invocationStack) {
   assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 41b86437e11cf..dc54ac94aed32 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -224,7 +224,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
 
 FailureOr<BaseMemRefType>
 AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
-                             BufferizationState &state,
+                             const BufferizationState &state,
                              SmallVector<Value> &invocationStack) {
   assert(value == getResult() && "invalid value");
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 9a0a85a71debd..a0168da44b7b3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -213,7 +213,7 @@ struct CallOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                BufferizationState &state,
+                const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto callOp = cast<func::CallOp>(op);
 
@@ -398,7 +398,7 @@ struct FuncOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                BufferizationState &state,
+                const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto funcOp = cast<FuncOp>(op);
     auto bbArg = cast<BlockArgument>(value);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index d971ed5b0f71c..784d95a5dd22a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -28,7 +28,7 @@ using namespace mlir::bufferization;
 
 LogicalResult mlir::bufferization::insertTensorCopies(
     Operation *op, const OneShotBufferizationOptions &options,
-    BufferizationState &bufferizationState,
+    const BufferizationState &bufferizationState,
     BufferizationStatistics *statistics) {
   OneShotAnalysisState analysisState(op, options);
   // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
@@ -50,7 +50,7 @@ LogicalResult mlir::bufferization::insertTensorCopies(
 
 LogicalResult mlir::bufferization::insertTensorCopies(
     Operation *op, const AnalysisState &analysisState,
-    BufferizationState &bufferizationState) {
+    const BufferizationState &bufferizationState) {
   IRRewriter rewriter(op->getContext());
 
   // It may be more efficient to walk in pre-order here, but the current
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index ce355e96ee694..9044d89c80bd6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -26,7 +26,7 @@ namespace {
 /// Generic conversion for any DestinationStyleOpInterface on tensors.
 static LogicalResult bufferizeDestinationStyleOpInterface(
     RewriterBase &rewriter, DestinationStyleOpInterface op,
-    const BufferizationOptions &options, BufferizationState &state) {
+    const BufferizationOptions &options, const BufferizationState &state) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(op);
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 59c240c62f934..46fa77a7dc4e6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -276,7 +276,7 @@ struct IfOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                BufferizationState &state,
+                const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto ifOp = cast<scf::IfOp>(op);
     auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
@@ -392,7 +392,7 @@ struct IndexSwitchOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                BufferizationState &state,
+                const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto switchOp = cast<scf::IndexSwitchOp>(op);
     assert(value.getDefiningOp() == op && "invalid value");
@@ -524,7 +524,7 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
 /// layout map and a cast must be inserted.
 static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
     Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
-    const BufferizationOptions &options, BufferizationState &state,
+    const BufferizationOptions &options, const BufferizationState &state,
     SmallVector<Value> &invocationStack) {
   // Determine the buffer type of the init_arg.
   auto initArgBufferType =
@@ -653,9 +653,10 @@ struct ForOpInterface
     return true;
   }
 
-  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
-                                 const AnalysisState &analysisState,
-                                 BufferizationState &bufferizationState) const {
+  LogicalResult
+  resolveConflicts(Operation *op, RewriterBase &rewriter,
+                   const AnalysisState &analysisState,
+                   const BufferizationState &bufferizationState) const {
     auto bufferizableOp = cast<BufferizableOpInterface>(op);
     if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
             rewriter, analysisState, bufferizationState)))
@@ -708,7 +709,7 @@ struct ForOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                BufferizationState &state,
+                const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto forOp = cast<scf::ForOp>(op);
     assert(getOwnerOfValue(value) == op && "invalid value");
@@ -900,9 +901,10 @@ struct WhileOpInterface
     return true;
   }
 
-  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
-                                 const AnalysisState &analysisState,
-                                 BufferizationState &bufferizationState) const {
+  LogicalResult
+  resolveConflicts(Operation *op, RewriterBase &rewriter,
+                   const AnalysisState &analysisState,
+                   const BufferizationState &bufferizationState) const {
     auto bufferizableOp = cast<BufferizableOpInterface>(op);
     if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
             rewriter, analysisState, bufferizationState)))
@@ -1042,7 +1044,7 @@ struct WhileOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                BufferizationState &state,
+                const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto whileOp = cast<scf::WhileOp>(op);
     assert(getOwnerOfValue(value) == op && "invalid value");
@@ -1298,7 +1300,7 @@ struct ForallOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                BufferizationState &state,
+                const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto forallOp = cast<ForallOp>(op);
 
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 154f12b31fc70..4b778b768d136 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -51,7 +51,7 @@ struct CastOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                BufferizationState &state,
+                const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto castOp = cast<tensor::CastOp>(op);
     auto maybeSrcBufferType = bufferization::getBufferType(
@@ -142,7 +142,7 @@ struct CollapseShapeOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                BufferizationState &state,
+                const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
     auto maybeSrcBufferType = bufferization::getBufferType(
@@ -320,7 +320,7 @@ struct ExpandShapeOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                BufferizationState &state,
+                const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
     auto maybeSrcBufferType = bufferization::getBufferType(
@@ -405,7 +405,7 @@ struct ExtractSliceOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                BufferizationState &state,
+                const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
     assert(value == extractSliceOp.getResult() && "invalid value");
@@ -754,7 +754,7 @@ struct PadOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                BufferizationState &state,
+                const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     // Infer memory space from the source tensor.
     auto padOp = cast<tensor::PadOp>(op);
@@ -927,7 +927,7 @@ struct ReshapeOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
-                BufferizationState &state,
+                const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto reshapeOp = cast<tensor::ReshapeOp>(op);
     assert(value == reshapeOp.getResult() && "unexpected value provided");
@@ -1023,9 +1023,10 @@ struct ParallelInsertSliceOpInterface
 
   /// tensor.parallel_insert_slice op has implicit inplace behavior. We
   /// shouldn't create copy to resolve conflict.
-  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
-                                 const AnalysisState &analysisState,
-                                 BufferizationState &bufferizationState) const {
+  LogicalResult
+  resolveConflicts(Operation *op, RewriterBase &rewriter,
+                   const AnalysisState &analysisState,
+                   const BufferizationState &bufferizationState) const {
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index a94a1d3567573..9da051150e409 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -186,9 +186,10 @@ struct MaskOpInterface
     return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
   }
 
-  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
-                                 const AnalysisState &analysisState,
-                                 BufferizationState &bufferizationState) const {
+  LogicalResult
+  resolveConflicts(Operation *op, RewriterBase &rewriter,
+                   const AnalysisState &analysisState,
+                   const BufferizationState &bufferizationState) const {
     auto bufferizableOp = cast<BufferizableOpInterface>(op);
     if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
             rewriter, analysisState, bufferizationState)))
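
Taken together, the hunks above amount to one mechanical change for downstream code: every BufferizableOpInterface hook and free helper that previously took only the BufferizationOptions now also receives the BufferizationState (by const reference for getBufferType, resolveConflicts, and the helpers such as getBuffer and allocateTensorForShapedValue; by mutable reference for bufferize). Below is a minimal sketch, not part of the patch, of how an external model might look after the three patches; mydialect::IdentityOp is a made-up op used purely for illustration, and the usual analysis hooks (bufferizesToMemoryRead, getAliasingValues, ...) are omitted for brevity.

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"

using namespace mlir;
using namespace mlir::bufferization;

// Hypothetical external model for a made-up "mydialect::IdentityOp" that
// forwards its single tensor operand to its single tensor result.
struct IdentityOpInterface
    : public BufferizableOpInterface::ExternalModel<IdentityOpInterface,
                                                    mydialect::IdentityOp> {
  // getBufferType now also receives the (const) BufferizationState and
  // forwards it to the bufferization::getBufferType helper.
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    // The result buffer type mirrors the operand's buffer type.
    return bufferization::getBufferType(op->getOperand(0), options, state,
                                        invocationStack);
  }

  // bufferize keeps a mutable BufferizationState and threads it into the
  // getBuffer helper, which now requires it as well.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    FailureOr<Value> buffer =
        getBuffer(rewriter, op->getOperand(0), options, state);
    if (failed(buffer))
      return failure();
    replaceOpWithBufferizedValues(rewriter, op, *buffer);
    return success();
  }
};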


