[Mlir-commits] [mlir] [MLIR] Add bufferization state class to OneShotBufferization pass (PR #138143)

Michele Scuttari llvmlistbot at llvm.org
Wed May 21 22:49:22 PDT 2025


https://github.com/mscuttari updated https://github.com/llvm/llvm-project/pull/138143

>From 0f4b5511a12b0e2c60ba5011310f80dd0b314189 Mon Sep 17 00:00:00 2001
From: Michele Scuttari <michele.scuttari at outlook.com>
Date: Thu, 1 May 2025 16:42:35 +0200
Subject: [PATCH 1/7] Add BufferizationState class

---
 .../Bufferization/IR/BufferizableOpInterface.h      | 13 +++++++++++++
 .../Bufferization/IR/BufferizableOpInterface.cpp    |  4 ++++
 2 files changed, 17 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index cb6ef8bc17220..891a5d9044852 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -578,6 +578,19 @@ class AnalysisState {
       insideMutuallyExclusiveRegionsCache;
 };
 
+/// BufferizationState provides information about the state of the IR during the
+/// bufferization process.
+class BufferizationState {
+public:
+  /// Get the cached symbol tables.
+  /// The user is expected to update or invalidate the cached symbol tables if
+  /// the bufferized operation has the Symbol or SymbolTable trait.
+  SymbolTableCollection &getSymbolTables();
+
+private:
+  SymbolTableCollection symbolTables;
+};
+
 /// Create an AllocTensorOp for the given shaped value (memref or tensor).
 /// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
 /// undefined contents is allocated.
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 1fc34051680f1..14fa4c1ed8159 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -125,6 +125,10 @@ void AnalysisState::resetCache() {
   insideMutuallyExclusiveRegionsCache.clear();
 }
 
+SymbolTableCollection &BufferizationState::getSymbolTables() {
+  return symbolTables;
+}
+
 Region *bufferization::getNextEnclosingRepetitiveRegion(
     Region *region, const BufferizationOptions &options) {
   assert(isRepetitiveRegion(region, options) && "expected repetitive region");

>From d461cff0cad4722f2977e08ba097d80a64776f8b Mon Sep 17 00:00:00 2001
From: Michele Scuttari <michele.scuttari at outlook.com>
Date: Thu, 1 May 2025 16:43:28 +0200
Subject: [PATCH 2/7] Add BufferizationState as argument to bufferize method

---
 .../IR/BufferizableOpInterface.td             |  3 +-
 .../Bufferization/IR/BufferizationOps.td      | 15 ++++--
 .../Bufferization/Transforms/BufferUtils.h    |  1 +
 .../Bufferization/Transforms/Bufferize.h      |  1 +
 .../Transforms/OneShotAnalysis.h              |  1 +
 .../Transforms/OneShotModuleBufferize.h       |  4 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |  1 +
 .../BufferizableOpInterfaceImpl.cpp           | 12 +++--
 .../Bufferization/IR/BufferizationOps.cpp     | 12 +++--
 .../BufferizationTransformOps.cpp             |  8 +++-
 .../Bufferization/Transforms/BufferUtils.cpp  |  7 +--
 .../Bufferization/Transforms/Bufferize.cpp    | 10 ++--
 .../FuncBufferizableOpInterfaceImpl.cpp       |  9 ++--
 .../Transforms/OneShotAnalysis.cpp            |  9 ++--
 .../Transforms/OneShotModuleBufferize.cpp     | 12 ++---
 .../BufferizableOpInterfaceImpl.cpp           |  3 +-
 .../BufferizableOpInterfaceImpl.cpp           |  7 ++-
 .../Transforms/ConvertToDestinationStyle.cpp  | 25 ++++++----
 .../BufferizableOpInterfaceImpl.cpp           | 17 +++++--
 .../BufferizableOpInterfaceImpl.cpp           | 27 +++++++----
 .../BufferizableOpInterfaceImpl.cpp           |  6 ++-
 .../BufferizableOpInterfaceImpl.cpp           |  3 +-
 .../SparsificationAndBufferizationPass.cpp    |  5 +-
 .../BufferizableOpInterfaceImpl.cpp           | 48 ++++++++++++-------
 .../BufferizableOpInterfaceImpl.cpp           | 15 ++++--
 25 files changed, 175 insertions(+), 86 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 95022d7d665d2..b599a9f053215 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -426,7 +426,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*retType=*/"::llvm::LogicalResult",
         /*methodName=*/"bufferize",
         /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
-                      "const ::mlir::bufferization::BufferizationOptions &":$options),
+                      "const ::mlir::bufferization::BufferizationOptions &":$options,
+                      "::mlir::bufferization::BufferizationState &":$state),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           llvm_unreachable("bufferize not implemented");
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 7a1a701bea6dc..dafa4b9b183f2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -93,7 +93,8 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
 
   let extraClassDeclaration = [{
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options);
+                            const BufferizationOptions &options,
+                            BufferizationState &state);
 
     bool resultBufferizesToMemoryWrite(OpResult opResult,
                                        const AnalysisState &state);
@@ -282,7 +283,8 @@ def Bufferization_MaterializeInDestinationOp
 
   let extraClassDeclaration = [{
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options);
+                            const BufferizationOptions &options,
+                            BufferizationState &state);
 
     bool bufferizesToMemoryRead(OpOperand &opOperand,
                                 const AnalysisState &state);
@@ -375,7 +377,8 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
     }
 
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options);
+                            const BufferizationOptions &options,
+                            BufferizationState &state);
   }];
 }
 
@@ -458,7 +461,8 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     //===------------------------------------------------------------------===//
 
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options) const {
+                            const BufferizationOptions &options,
+                            BufferizationState &state) const {
       // to_tensor/to_buffer pairs fold away after bufferization.
       return success();
     }
@@ -550,7 +554,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
     }
 
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options);
+                            const BufferizationOptions &options,
+                            BufferizationState &state);
   }];
 
   let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
index e5f3b6d571f43..adeb52cf9d7e6 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
@@ -122,6 +122,7 @@ class BufferPlacementTransformationBase {
 // Globals are created lazily at the top of the enclosing ModuleOp with pretty
 // names. Duplicates are avoided.
 FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
+                                         SymbolTableCollection &symbolTables,
                                          uint64_t alignment,
                                          Attribute memorySpace = {});
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index d5cb8d8eb673c..70e3defee0867 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -45,6 +45,7 @@ struct BufferizationStatistics {
 /// additional buffer copies or set "options.copyBeforeWrite = true". The
 /// general bufferization entry point is `runOneShotBufferize`.
 LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
+                          BufferizationState &bufferizationState,
                           BufferizationStatistics *statistics = nullptr);
 
 /// Bufferize the signature of `block` and its callers (i.e., ops that have the
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 673027f76190d..15189d2c1cb87 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -270,6 +270,7 @@ LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state,
 /// Run One-Shot Bufferize on the given op: Analysis + Bufferization
 LogicalResult
 runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options,
+                    BufferizationState &state,
                     BufferizationStatistics *statistics = nullptr);
 
 } // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
index 4e5f5e9c730fa..2cf801dd1d951 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
@@ -20,6 +20,7 @@ namespace bufferization {
 struct BufferizationStatistics;
 class OneShotAnalysisState;
 struct OneShotBufferizationOptions;
+class BufferizationState;
 
 /// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
 /// `state`.
@@ -38,6 +39,7 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
 ///   will be inserted only to these FuncOps.
 llvm::LogicalResult
 bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+                  BufferizationState &state,
                   BufferizationStatistics *statistics = nullptr);
 
 /// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
@@ -50,7 +52,7 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp);
 llvm::LogicalResult runOneShotModuleBufferize(
     ModuleOp moduleOp,
     const bufferization::OneShotBufferizationOptions &options,
-    BufferizationStatistics *statistics = nullptr);
+    BufferizationState &state, BufferizationStatistics *statistics = nullptr);
 
 } // namespace bufferization
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 4f90fc8831bc6..2eef0a06d0eb4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -30,6 +30,7 @@ namespace mlir {
 namespace bufferization {
 class AllocTensorOp;
 class OneShotAnalysisState;
+class BufferizationState;
 } // namespace bufferization
 
 namespace linalg {
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 5e69a98db8f1e..f646326ffc58f 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -24,7 +24,8 @@ struct ConstantOpInterface
     : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
                                                     arith::ConstantOp> {
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto constantOp = cast<arith::ConstantOp>(op);
     auto type = dyn_cast<RankedTensorType>(constantOp.getType());
 
@@ -46,7 +47,8 @@ struct ConstantOpInterface
     // Create global memory segment and replace tensor with memref pointing to
     // that memory segment.
     FailureOr<memref::GlobalOp> globalOp =
-        getGlobalFor(constantOp, options.bufferAlignment, memorySpace);
+        getGlobalFor(constantOp, state.getSymbolTables(),
+                     options.bufferAlignment, memorySpace);
     if (failed(globalOp))
       return failure();
     memref::GlobalOp globalMemref = *globalOp;
@@ -83,7 +85,8 @@ struct IndexCastOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto castOp = cast<arith::IndexCastOp>(op);
     auto resultTensorType = cast<TensorType>(castOp.getType());
 
@@ -131,7 +134,8 @@ struct SelectOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto selectOp = cast<arith::SelectOp>(op);
     Location loc = selectOp.getLoc();
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index ecd2ef15546a4..91eccb0ab7430 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -149,7 +149,8 @@ void mlir::bufferization::populateDynamicDimSizes(
 //===----------------------------------------------------------------------===//
 
 LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
-                                       const BufferizationOptions &options) {
+                                       const BufferizationOptions &options,
+                                       BufferizationState &state) {
   OpBuilder::InsertionGuard g(rewriter);
   Location loc = getLoc();
 
@@ -529,7 +530,8 @@ void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
 //===----------------------------------------------------------------------===//
 
 LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
-                                         const BufferizationOptions &options) {
+                                         const BufferizationOptions &options,
+                                         BufferizationState &state) {
   FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
   if (failed(buffer))
     return failure();
@@ -576,7 +578,8 @@ MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
 
 LogicalResult
 MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
-                                      const BufferizationOptions &options) {
+                                      const BufferizationOptions &options,
+                                      BufferizationState &state) {
   bool tensorDest = isa<TensorType>(getDest().getType());
   Value buffer;
   if (tensorDest) {
@@ -861,7 +864,8 @@ void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
-                                    const BufferizationOptions &options) {
+                                    const BufferizationOptions &options,
+                                    BufferizationState &state) {
   // Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
   (void)foldToBufferToTensorPair(rewriter, *this, options);
   // Note: The return value of `bufferize` indicates whether there was an error
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index a1d7bb995fc73..8bb7942304274 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -83,6 +83,8 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
   }
 
   auto payloadOps = state.getPayloadOps(getTarget());
+  BufferizationState bufferizationState;
+
   for (Operation *target : payloadOps) {
     if (!isa<ModuleOp, FunctionOpInterface>(target))
       return emitSilenceableError() << "expected module or function target";
@@ -90,10 +92,12 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
     if (options.bufferizeFunctionBoundaries) {
       if (!moduleOp)
         return emitSilenceableError() << "expected module target";
-      if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
+      if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options,
+                                                          bufferizationState)))
         return emitSilenceableError() << "bufferization failed";
     } else {
-      if (failed(bufferization::runOneShotBufferize(target, options)))
+      if (failed(bufferization::runOneShotBufferize(target, options,
+                                                    bufferizationState)))
         return emitSilenceableError() << "bufferization failed";
     }
   }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index c2e90764b1335..bb21f642ac077 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -103,8 +103,9 @@ BufferPlacementTransformationBase::BufferPlacementTransformationBase(
 //===----------------------------------------------------------------------===//
 
 FailureOr<memref::GlobalOp>
-bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
-                            Attribute memorySpace) {
+bufferization::getGlobalFor(arith::ConstantOp constantOp,
+                            SymbolTableCollection &symbolTables,
+                            uint64_t alignment, Attribute memorySpace) {
   auto type = cast<RankedTensorType>(constantOp.getType());
   auto moduleOp = constantOp->getParentOfType<ModuleOp>();
   if (!moduleOp)
@@ -127,7 +128,7 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
   // Create a builder without an insertion point. We will insert using the
   // symbol table to guarantee unique names.
   OpBuilder globalBuilder(moduleOp.getContext());
-  SymbolTable symbolTable(moduleOp);
+  SymbolTable &symbolTable = symbolTables.getSymbolTable(moduleOp);
 
   // Create a pretty name.
   SmallString<64> buf;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 824b505517119..67f373d912dd4 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -161,10 +161,12 @@ struct OneShotBufferizePass
       return signalPassFailure();
     }
 
+    BufferizationState state;
     BufferizationStatistics statistics;
     ModuleOp moduleOp = getOperation();
     if (opt.bufferizeFunctionBoundaries) {
-      if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
+      if (failed(
+              runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
         signalPassFailure();
         return;
       }
@@ -175,7 +177,7 @@ struct OneShotBufferizePass
                   "'bufferize-function-boundaries'");
         return signalPassFailure();
       }
-      if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
+      if (failed(runOneShotBufferize(moduleOp, opt, state, &statistics))) {
         signalPassFailure();
         return;
       }
@@ -275,6 +277,7 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
 
 LogicalResult bufferization::bufferizeOp(Operation *op,
                                          const BufferizationOptions &options,
+                                         BufferizationState &bufferizationState,
                                          BufferizationStatistics *statistics) {
   if (options.copyBeforeWrite) {
     AnalysisState state(options);
@@ -331,7 +334,8 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
                << "//===-------------------------------------------===//\n"
                << "IR after bufferizing: " << nextOp->getName() << "\n");
     rewriter.setInsertionPoint(nextOp);
-    if (failed(bufferizableOp.bufferize(rewriter, options))) {
+    if (failed(
+            bufferizableOp.bufferize(rewriter, options, bufferizationState))) {
       LLVM_DEBUG(llvm::dbgs()
                  << "failed to bufferize\n"
                  << "//===-------------------------------------------===//\n");
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 755477713668e..080796208bfc1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -239,7 +239,8 @@ struct CallOpInterface
   /// All function arguments are writable. It is the responsibility of the
   /// CallOp to insert buffer copies where necessary.
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     func::CallOp callOp = cast<func::CallOp>(op);
 
     // 1. Compute the result types of the new CallOp.
@@ -349,7 +350,8 @@ struct ReturnOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
 #ifndef NDEBUG
     auto returnOp = cast<func::ReturnOp>(op);
     assert(isa<FuncOp>(returnOp->getParentOp()) &&
@@ -418,7 +420,8 @@ struct FuncOpInterface
   /// All function bbArgs are writable unless they are explicitly marked as
   /// read-only. Callers must insert copies when needed.
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto funcOp = cast<FuncOp>(op);
     FunctionType funcType = funcOp.getFunctionType();
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 6e93b36d2d5a2..de820e9c8f8af 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -1365,10 +1365,9 @@ LogicalResult bufferization::analyzeOp(Operation *op,
   return success(!failedAnalysis);
 }
 
-LogicalResult
-bufferization::runOneShotBufferize(Operation *op,
-                                   const OneShotBufferizationOptions &options,
-                                   BufferizationStatistics *statistics) {
+LogicalResult bufferization::runOneShotBufferize(
+    Operation *op, const OneShotBufferizationOptions &options,
+    BufferizationState &state, BufferizationStatistics *statistics) {
   // copy-before-write deactivates the analysis. It cannot be used together with
   // test-analysis-only.
   assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
@@ -1391,5 +1390,5 @@ bufferization::runOneShotBufferize(Operation *op,
 
   // Bufferize the op and its nested ops. If options.copyBeforeWrite is set,
   // a new buffer copy is allocated every time a buffer is written to.
-  return bufferizeOp(op, options, statistics);
+  return bufferizeOp(op, options, state, statistics);
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index a025da8635135..90ceea4d69680 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -512,7 +512,7 @@ void mlir::bufferization::removeBufferizationAttributesInModule(
 
 LogicalResult mlir::bufferization::bufferizeModuleOp(
     ModuleOp moduleOp, const OneShotBufferizationOptions &options,
-    BufferizationStatistics *statistics) {
+    BufferizationState &state, BufferizationStatistics *statistics) {
   assert(options.bufferizeFunctionBoundaries &&
          "expected that function boundary bufferization is activated");
   IRRewriter rewriter(moduleOp.getContext());
@@ -548,10 +548,10 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
       // Buffer copies must be inserted before every write.
       OneShotBufferizationOptions updatedOptions = options;
       updatedOptions.copyBeforeWrite = true;
-      if (failed(bufferizeOp(funcOp, updatedOptions, statistics)))
+      if (failed(bufferizeOp(funcOp, updatedOptions, state, statistics)))
         return failure();
     } else {
-      if (failed(bufferizeOp(funcOp, options, statistics)))
+      if (failed(bufferizeOp(funcOp, options, state, statistics)))
         return failure();
     }
 
@@ -565,7 +565,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
     // Functions were already bufferized.
     if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
       continue;
-    if (failed(bufferizeOp(&op, options, statistics)))
+    if (failed(bufferizeOp(&op, options, state, statistics)))
       return failure();
   }
 
@@ -577,7 +577,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
 
 LogicalResult mlir::bufferization::runOneShotModuleBufferize(
     ModuleOp moduleOp, const OneShotBufferizationOptions &options,
-    BufferizationStatistics *statistics) {
+    BufferizationState &state, BufferizationStatistics *statistics) {
   assert(options.bufferizeFunctionBoundaries &&
          "expected that function boundary bufferization is activated");
   assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
@@ -606,7 +606,7 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
   }
   if (options.testAnalysisOnly)
     return success();
-  if (failed(bufferizeModuleOp(moduleOp, options, statistics)))
+  if (failed(bufferizeModuleOp(moduleOp, options, state, statistics)))
     return failure();
   return success();
 }
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp
index 72f4a1a4f4c66..6a1546fb48683 100644
--- a/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -43,7 +43,8 @@ struct BranchLikeOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     // The operands of this op are bufferized together with the block signature.
     return success();
   }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index be158af09d398..b6a498a57c036 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -148,7 +148,8 @@ struct LinalgOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     return bufferizeDestinationStyleOpInterface(
         rewriter, cast<DestinationStyleOpInterface>(op), options);
   }
@@ -174,7 +175,8 @@ struct SoftmaxOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto softmaxOp = cast<linalg::SoftmaxOp>(op);
     FailureOr<Value> inputBuffer =
         getBuffer(rewriter, softmaxOp.getInput(), options);
@@ -202,6 +204,7 @@ void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
     LinalgOpInterfaceHelper<
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+
         >::registerOpInterface(ctx);
 
     SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index b1340be04e011..f18a31b97967b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -263,7 +263,11 @@ Value linalg::bufferizeToAllocation(
   assert(llvm::range_size(maskOp.getMaskBlock()->without_terminator()) == 1 &&
          "expected single masked op");
   OpBuilder::InsertionGuard g(rewriter);
+
+  // Should the bufferization options and state be function arguments?
   bufferization::BufferizationOptions bufferizationOptions;
+  bufferization::BufferizationState bufferizationState;
+
   Operation *yieldOp = maskOp.getMaskRegion().front().getTerminator();
   assert(isa<vector::YieldOp>(yieldOp) && "expected yield op terminator");
 
@@ -279,7 +283,7 @@ Value linalg::bufferizeToAllocation(
   // Bufferize terminator.
   rewriter.setInsertionPoint(yieldOp);
   if (failed(cast<bufferization::BufferizableOpInterface>(yieldOp).bufferize(
-          rewriter, bufferizationOptions)))
+          rewriter, bufferizationOptions, bufferizationState)))
     return nullptr;
 
   // Erase dead to_tensor ops inside of the mask op. This is necessary because
@@ -300,8 +304,9 @@ Value linalg::bufferizeToAllocation(
       for (OpOperand &use : result.getUses())
         resultUses.push_back(&use);
   rewriter.setInsertionPoint(maskOp);
-  if (failed(cast<bufferization::BufferizableOpInterface>(maskOp.getOperation())
-                 .bufferize(rewriter, bufferizationOptions)))
+  if (failed(
+          cast<bufferization::BufferizableOpInterface>(maskOp.getOperation())
+              .bufferize(rewriter, bufferizationOptions, bufferizationState)))
     return nullptr;
 
   // Set "restrict" attribute, indicating that no other tensor aliases with
@@ -484,8 +489,11 @@ Value linalg::bufferizeToAllocation(
   auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
   if (!bufferizableOp)
     return nullptr;
+
+  // Should the bufferization options and states be function arguments?
   BufferizationOptions bufferizationOptions;
-  AnalysisState state(bufferizationOptions);
+  AnalysisState analysisState(bufferizationOptions);
+  BufferizationState bufferizationState;
 
 #ifndef NDEBUG
   if (!options.bufferizeDestinationOnly) {
@@ -527,7 +535,7 @@ Value linalg::bufferizeToAllocation(
   };
   for (OpResult result : tensorResults) {
     AliasingOpOperandList aliasingOperands =
-        state.getAliasingOpOperands(result);
+        analysisState.getAliasingOpOperands(result);
     for (const AliasingOpOperand &operand : aliasingOperands) {
       addOutOfPlaceOperand(operand.opOperand);
       for (OpOperand &resultUse : result.getUses())
@@ -535,7 +543,7 @@ Value linalg::bufferizeToAllocation(
     }
   }
   for (OpOperand &operand : op->getOpOperands()) {
-    if (!state.bufferizesToMemoryWrite(operand))
+    if (!analysisState.bufferizesToMemoryWrite(operand))
       continue;
     if (!isa<RankedTensorType>(operand.get().getType()))
       continue;
@@ -553,7 +561,7 @@ Value linalg::bufferizeToAllocation(
     Value alloc = createAllocationForTensor(
         rewriter, op->getLoc(), operand->get(), options, memorySpace);
     allocs.push_back(alloc);
-    if (!state.findDefinitions(operand).empty()) {
+    if (!analysisState.findDefinitions(operand).empty()) {
       // Initialize buffer with a copy of the operand data. Not needed if the
       // tensor is uninitialized.
       createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
@@ -575,7 +583,8 @@ Value linalg::bufferizeToAllocation(
 
   // Bufferize the op.
   rewriter.setInsertionPoint(op);
-  if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions)))
+  if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions,
+                                      bufferizationState)))
     return nullptr;
 
   // Set "restrict" attribute, indicating that no other tensor aliases with
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
index 926d580ac7852..104ec3e1449e5 100644
--- a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -52,15 +52,21 @@ struct GlobalOpInterface
   bool hasTensorSemantics(Operation *) const { return true; }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &) const {
+                          const BufferizationOptions &,
+                          BufferizationState &state) const {
     auto globalOp = cast<GlobalOp>(op);
     if (!globalOp.getValue().has_value())
       return globalOp.emitError("global op must have a value");
 
+    SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
+        globalOp->getParentWithTrait<OpTrait::SymbolTable>());
+
+    symbolTable.remove(globalOp);
+
     auto tensorType = cast<TensorType>(globalOp.getType());
     auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
 
-    replaceOpWithNewBufferizedOp<memref::GlobalOp>(
+    auto replacement = replaceOpWithNewBufferizedOp<memref::GlobalOp>(
         rewriter, globalOp, globalOp.getSymName(),
         /*sym_visibility=*/globalOp.getSymVisibilityAttr(),
         /*type=*/cast<MemRefType>(memrefType),
@@ -68,6 +74,7 @@ struct GlobalOpInterface
         /*constant=*/!globalOp.getIsMutable(),
         /*alignment=*/nullptr);
 
+    symbolTable.insert(replacement);
     return success();
   }
 };
@@ -91,7 +98,8 @@ struct GlobalLoadOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &) const {
+                          const BufferizationOptions &,
+                          BufferizationState &state) const {
     auto globalLoadOp = cast<GlobalLoadOp>(op);
 
     auto tensorType = cast<TensorType>(globalLoadOp.getType());
@@ -121,7 +129,8 @@ struct GlobalStoreOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto globalStoreOp = cast<GlobalStoreOp>(op);
 
     auto tensorType = cast<TensorType>(globalStoreOp.getValue().getType());
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index d6a9d8f6401f1..3ff1f5c49aece 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -95,7 +95,8 @@ struct ConditionOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto conditionOp = cast<scf::ConditionOp>(op);
     auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());
 
@@ -181,7 +182,8 @@ struct ExecuteRegionOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
     auto yieldOp = getUniqueYieldOp(executeRegionOp);
     TypeRange newResultTypes(yieldOp.getResults());
@@ -237,7 +239,8 @@ struct IfOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     OpBuilder::InsertionGuard g(rewriter);
     auto ifOp = cast<scf::IfOp>(op);
 
@@ -347,7 +350,8 @@ struct IndexSwitchOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     OpBuilder::InsertionGuard g(rewriter);
     auto switchOp = cast<scf::IndexSwitchOp>(op);
 
@@ -722,7 +726,8 @@ struct ForOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto forOp = cast<scf::ForOp>(op);
     Block *oldLoopBody = forOp.getBody();
 
@@ -939,7 +944,8 @@ struct WhileOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto whileOp = cast<scf::WhileOp>(op);
 
     // Indices of all bbArgs that have tensor type. These are the ones that
@@ -1144,7 +1150,8 @@ struct YieldOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto yieldOp = cast<scf::YieldOp>(op);
     if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
              scf::WhileOp>(yieldOp->getParentOp()))
@@ -1220,7 +1227,8 @@ struct ForallOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     OpBuilder::InsertionGuard guard(rewriter);
     auto forallOp = cast<ForallOp>(op);
     int64_t rank = forallOp.getRank();
@@ -1327,7 +1335,8 @@ struct InParallelOpInterface
     : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
                                                     InParallelOp> {
   LogicalResult bufferize(Operation *op, RewriterBase &b,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     llvm_unreachable("op does not have any tensor OpOperands / OpResults");
     return failure();
   }
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index 6c3b23937f98f..e8cab76d3c753 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -47,7 +47,8 @@ struct AssumingOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto assumingOp = cast<shape::AssumingOp>(op);
     assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) &&
            "only 1 block supported");
@@ -112,7 +113,8 @@ struct AssumingYieldOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto yieldOp = cast<shape::AssumingYieldOp>(op);
     SmallVector<Value> newResults;
     for (Value value : yieldOp.getOperands()) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 7734d1d258453..f952b68ba7e67 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -30,7 +30,8 @@ template <typename ConcreteModel, typename ConcreteOp>
 struct SparseBufferizableOpInterfaceExternalModel
     : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     return op->emitError(
         "sparse_tensor ops must be bufferized with the sparsifier");
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 6e882a8d0ff30..7c7c64f2aef01 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -114,8 +114,11 @@ class SparsificationAndBufferizationPass
       return false;
     });
 
+    bufferization::BufferizationState bufferizationState;
+
     if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()),
-                                                updatedOptions)))
+                                                updatedOptions,
+                                                bufferizationState)))
       return failure();
 
     bufferization::removeBufferizationAttributesInModule(getOperation());
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index c0e697292d2a0..ac1e90b9f9b35 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -83,7 +83,8 @@ struct CastOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto castOp = cast<tensor::CastOp>(op);
 
     // The result buffer still has the old (pre-cast) type.
@@ -162,7 +163,8 @@ struct CollapseShapeOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
     RankedTensorType tensorResultType = collapseShapeOp.getResultType();
     FailureOr<Value> maybeBuffer =
@@ -247,7 +249,8 @@ struct DimOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto dimOp = cast<tensor::DimOp>(op);
     FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
     if (failed(v))
@@ -271,7 +274,8 @@ struct EmptyOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto emptyOp = cast<tensor::EmptyOp>(op);
 
     // Optimization: Fold away the op if it has no uses.
@@ -329,7 +333,8 @@ struct ExpandShapeOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
     auto tensorResultType = expandShapeOp.getResultType();
     FailureOr<Value> buffer =
@@ -367,7 +372,8 @@ struct ExtractSliceOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
@@ -432,7 +438,8 @@ struct ExtractOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto extractOp = cast<tensor::ExtractOp>(op);
     FailureOr<Value> srcMemref =
         getBuffer(rewriter, extractOp.getTensor(), options);
@@ -474,7 +481,8 @@ struct FromElementsOpInterface
   bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
     auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
 
@@ -586,7 +594,8 @@ struct GenerateOpInterface
   bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto generateOp = cast<tensor::GenerateOp>(op);
 
     auto type = generateOp.getResult().getType();
@@ -620,7 +629,8 @@ struct InsertOpInterface
     : public DstBufferizableOpInterfaceExternalModel<InsertOpInterface,
                                                      tensor::InsertOp> {
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto insertOp = cast<tensor::InsertOp>(op);
     FailureOr<Value> destMemref =
         getBuffer(rewriter, insertOp.getDest(), options);
@@ -670,7 +680,8 @@ struct InsertSliceOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     // insert_slice ops arise from tiling and bufferizing them out-of-place is
     // generally a deal breaker. When used with loops, this ends up cloning the
     // whole tensor on every single iteration and is a symptom of a
@@ -752,7 +763,8 @@ struct PadOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto padOp = cast<tensor::PadOp>(op);
     Location loc = padOp.getLoc();
     RankedTensorType resultType = padOp.getResultType();
@@ -831,7 +843,8 @@ struct RankOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto rankOp = cast<tensor::RankOp>(op);
     FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
     if (failed(v))
@@ -868,7 +881,8 @@ struct ReshapeOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto reshapeOp = cast<tensor::ReshapeOp>(op);
     FailureOr<Value> srcBuffer =
         getBuffer(rewriter, reshapeOp.getSource(), options);
@@ -940,7 +954,8 @@ struct ParallelInsertSliceOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     OpBuilder::InsertionGuard g(rewriter);
     auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
     ParallelCombiningOpInterface parallelCombiningParent =
@@ -1015,7 +1030,8 @@ struct SplatOpInterface
   bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     OpBuilder::InsertionGuard g(rewriter);
     auto splatOp = cast<tensor::SplatOp>(op);
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index b2272c5fda876..45b6e7c512947 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -48,7 +48,8 @@ struct TransferReadOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto readOp = cast<vector::TransferReadOp>(op);
     assert(isa<TensorType>(readOp.getShapedType()) &&
            "only tensor types expected");
@@ -103,7 +104,8 @@ struct TransferWriteOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto writeOp = cast<vector::TransferWriteOp>(op);
     assert(isa<TensorType>(writeOp.getShapedType()) &&
            "only tensor types expected");
@@ -148,7 +150,8 @@ struct GatherOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto gatherOp = cast<vector::GatherOp>(op);
     assert(isa<TensorType>(gatherOp.getBaseType()) &&
            "only tensor types expected");
@@ -202,7 +205,8 @@ struct MaskOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto maskOp = cast<vector::MaskOp>(op);
 
     // Do not bufferize if the masked op is not bufferizable.
@@ -279,7 +283,8 @@ struct YieldOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options) const {
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
     auto yieldOp = cast<vector::YieldOp>(op);
 
     // Only supported as a vector.mask terminator.

>From 07344323181d67c16fe435810e1900fd20e6bf89 Mon Sep 17 00:00:00 2001
From: Michele Scuttari <michele.scuttari at outlook.com>
Date: Sat, 17 May 2025 14:26:21 +0200
Subject: [PATCH 3/7] Add extension mechanism to BufferizationState

---
 .../IR/BufferizableOpInterface.h              | 75 +++++++++++++++++--
 .../Bufferization/Transforms/BufferUtils.h    | 10 +++
 .../BufferizableOpInterfaceImpl.cpp           |  3 +-
 .../IR/BufferizableOpInterface.cpp            |  4 -
 .../Bufferization/Transforms/BufferUtils.cpp  | 39 ++++++++++
 .../BufferizableOpInterfaceImpl.cpp           |  8 +-
 6 files changed, 123 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 891a5d9044852..e2c75b9b230fa 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -582,12 +582,77 @@ class AnalysisState {
 /// bufferization process.
 class BufferizationState {
 public:
-  /// Get the cached symbol tables.
-  /// The user is expected to update / invalidate the cached symbol tables if
-  /// the bufferized operation have the Symbol or SymbolTable traits.
-  SymbolTableCollection &getSymbolTables();
+  /// Base class for BufferizationState extensions that allow BufferizationState
+  /// to contain user-specified information in the state object. The extension
+  /// mechanism of BufferizationState mirrors the one of OneShotAnalysisState.
+  class Extension {
+  public:
+    /// Base virtual destructor.
+    // Out-of-line definition ensures symbols are emitted in a single object
+    // file.
+    virtual ~Extension();
+
+  protected:
+    /// Constructs an extension of the given state object.
+    explicit Extension(BufferizationState &state) : state(state) {}
+
+    /// Provides read-only access to the parent BufferizationState object.
+    const BufferizationState &getBufferizationState() const { return state; }
+
+  private:
+    /// Back-reference to the state that is being extended.
+    BufferizationState &state;
+  };
 
-private:
+  /// Adds a new Extension of the type specified as template parameter,
+  /// constructing it with the arguments provided. The extension is owned by the
+  /// BufferizationState. It is expected that the state does not already have an
+  /// extension of the same type. Extension constructors are expected to take a
+  /// reference to BufferizationState as first argument, automatically supplied
+  /// by this call.
+  template <typename Ty, typename... Args>
+  Ty &addExtension(Args &&...args) {
+    static_assert(std::is_base_of<Extension, Ty>::value,
+                  "only a class derived from "
+                  "BufferizationState::Extension is allowed");
+    auto ptr = std::make_unique<Ty>(*this, std::forward<Args>(args)...);
+    auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
+    assert(result.second && "extension already added");
+    return *static_cast<Ty *>(result.first->second.get());
+  }
+
+  /// Returns the extension of the specified type, or nullptr if not present.
+  template <typename Ty>
+  Ty *getExtension() {
+    static_assert(std::is_base_of<Extension, Ty>::value,
+                  "only a class derived from "
+                  "BufferizationState::Extension is allowed");
+    auto iter = extensions.find(TypeID::get<Ty>());
+    if (iter == extensions.end())
+      return nullptr;
+    return static_cast<Ty *>(iter->second.get());
+  }
+
+  /// Returns the extension of the specified type, or nullptr if not present.
+  template <typename Ty>
+  const Ty *getExtension() const {
+    return const_cast<BufferizationState *>(this)->getExtension<Ty>();
+  }
+
+private:
+  /// Extensions attached to the state, keyed by TypeID; at most one per type.
+  DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
+};
+
+/// Extra bufferization state that is required for bufferization of operations
+/// declaring a symbol or a symbol table.
+struct SymbolBufferizationState : public BufferizationState::Extension {
+  explicit SymbolBufferizationState(BufferizationState &state)
+      : BufferizationState::Extension(state) {}
+
+  /// The cached symbol tables.
+  /// The user is expected to update / invalidate the cached symbol tables if
+  /// the bufferized operation has the Symbol or SymbolTable traits.
   SymbolTableCollection symbolTables;
 };
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
index adeb52cf9d7e6..da0cbe31b0420 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
@@ -29,6 +29,7 @@ class GlobalOp;
 } // namespace memref
 
 namespace bufferization {
+class BufferizationState;
 
 /// A simple analysis that detects allocation operations.
 class BufferPlacementAllocs {
@@ -126,6 +127,15 @@ FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
                                          uint64_t alignment,
                                          Attribute memorySpace = {});
 
+/// As above, but reuses the symbol tables cached in `state`, if available.
+FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp op,
+                                         BufferizationState &state,
+                                         uint64_t alignment,
+                                         Attribute memorySpace = {});
+
+/// Remove/insert `op` in its parent's symbol table cached in `state`, if any.
+void removeSymbol(Operation *op, BufferizationState &state);
+void insertSymbol(Operation *op, BufferizationState &state);
 } // namespace bufferization
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index f646326ffc58f..1eabafaca261a 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -47,8 +47,7 @@ struct ConstantOpInterface
     // Create global memory segment and replace tensor with memref pointing to
     // that memory segment.
     FailureOr<memref::GlobalOp> globalOp =
-        getGlobalFor(constantOp, state.getSymbolTables(),
-                     options.bufferAlignment, memorySpace);
+        getGlobalFor(constantOp, state, options.bufferAlignment, memorySpace);
     if (failed(globalOp))
       return failure();
     memref::GlobalOp globalMemref = *globalOp;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 14fa4c1ed8159..1fc34051680f1 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -125,10 +125,6 @@ void AnalysisState::resetCache() {
   insideMutuallyExclusiveRegionsCache.clear();
 }
 
-SymbolTableCollection &BufferizationState::getSymbolTables() {
-  return symbolTables;
-}
-
 Region *bufferization::getNextEnclosingRepetitiveRegion(
     Region *region, const BufferizationOptions &options) {
   assert(isRepetitiveRegion(region, options) && "expected repetitive region");
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index bb21f642ac077..a5aeb2d1ebb08 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -159,3 +159,42 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp,
   global->moveBefore(&moduleOp.front());
   return global;
 }
+
+namespace mlir::bufferization {
+FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp op,
+                                         BufferizationState &state,
+                                         uint64_t alignment,
+                                         Attribute memorySpace) {
+  // Symbol tables used when the state carries no SymbolBufferizationState
+  // extension. They only live for the duration of this call.
+  SymbolTableCollection fallbackTables;
+
+  // Prefer the symbol tables cached in the state, if present.
+  auto *symbolState = state.getExtension<SymbolBufferizationState>();
+  SymbolTableCollection &tables =
+      symbolState ? symbolState->symbolTables : fallbackTables;
+  return getGlobalFor(op, tables, alignment, memorySpace);
+}
+
+/// Returns the cached symbol table of the closest symbol-table ancestor of
+/// `op`, or nullptr if `state` carries no SymbolBufferizationState.
+static SymbolTable *getCachedParentSymbolTable(Operation *op,
+                                               BufferizationState &state) {
+  auto *symbolState = state.getExtension<SymbolBufferizationState>();
+  if (!symbolState)
+    return nullptr;
+  Operation *symbolTableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
+  assert(symbolTableOp && "expected op to be nested in a symbol table");
+  return &symbolState->symbolTables.getSymbolTable(symbolTableOp);
+}
+
+void removeSymbol(Operation *op, BufferizationState &state) {
+  if (SymbolTable *symbolTable = getCachedParentSymbolTable(op, state))
+    symbolTable->remove(op);
+}
+
+void insertSymbol(Operation *op, BufferizationState &state) {
+  if (SymbolTable *symbolTable = getCachedParentSymbolTable(op, state))
+    symbolTable->insert(op);
+}
+} // namespace mlir::bufferization
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
index 104ec3e1449e5..a69bc9e5088ae 100644
--- a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 
@@ -58,10 +59,7 @@ struct GlobalOpInterface
     if (!globalOp.getValue().has_value())
       return globalOp.emitError("global op must have a value");
 
-    SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
-        globalOp->getParentWithTrait<OpTrait::SymbolTable>());
-
-    symbolTable.remove(globalOp);
+    bufferization::removeSymbol(globalOp, state);
 
     auto tensorType = cast<TensorType>(globalOp.getType());
     auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
@@ -74,7 +72,7 @@ struct GlobalOpInterface
         /*constant=*/!globalOp.getIsMutable(),
         /*alignment=*/nullptr);
 
-    symbolTable.insert(replacement);
+    bufferization::insertSymbol(replacement, state);
     return success();
   }
 };

>From 8bd6a16b6e6552385052c9129dbe5dd5f3034e0a Mon Sep 17 00:00:00 2001
From: Michele Scuttari <michele.scuttari at outlook.com>
Date: Sat, 17 May 2025 15:12:55 +0200
Subject: [PATCH 4/7] Add missing implementation for Extension destructor

---
 mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 1fc34051680f1..0da720ad6da28 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -125,6 +125,8 @@ void AnalysisState::resetCache() {
   insideMutuallyExclusiveRegionsCache.clear();
 }
 
+BufferizationState::Extension::~Extension() = default;
+
 Region *bufferization::getNextEnclosingRepetitiveRegion(
     Region *region, const BufferizationOptions &options) {
   assert(isRepetitiveRegion(region, options) && "expected repetitive region");

>From 45e03837c11c263f21805974a205a813bae2b849 Mon Sep 17 00:00:00 2001
From: Michele Scuttari <michele.scuttari at outlook.com>
Date: Sat, 17 May 2025 15:17:01 +0200
Subject: [PATCH 5/7] Add option to enable caching of symbol tables

---
 .../mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h | 6 ++++++
 .../TransformOps/BufferizationTransformOps.cpp              | 5 +++++
 mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp     | 5 +++++
 .../Transforms/SparsificationAndBufferizationPass.cpp       | 5 +++++
 4 files changed, 21 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 15189d2c1cb87..fa6a08320bd60 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -52,6 +52,12 @@ struct OneShotBufferizationOptions : public BufferizationOptions {
   /// `AnalysisHeuristic::Fuzzer`. The fuzzer should be used only with
   /// `testAnalysisOnly = true`.
   unsigned analysisFuzzerSeed = 0;
+
+  /// Enable caching of symbol tables. If enabled, the SymbolBufferizationState
+  /// class is attached to the bufferization state and the user is required to
+  /// keep the cached symbol tables consistent with respect to the performed
+  /// bufferizations.
+  bool cacheSymbolTables = false;
 };
 
 /// State for analysis-enabled bufferization. This class keeps track of alias
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index 8bb7942304274..a6cae1f4dda33 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -85,6 +85,10 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
   auto payloadOps = state.getPayloadOps(getTarget());
   BufferizationState bufferizationState;
 
+  if (options.cacheSymbolTables) {
+    bufferizationState.addExtension<SymbolBufferizationState>();
+  }
+
   for (Operation *target : payloadOps) {
     if (!isa<ModuleOp, FunctionOpInterface>(target))
       return emitSilenceableError() << "expected module or function target";
@@ -166,6 +170,7 @@ class BufferizationTransformDialectExtension
     registerTransformOps<
 #define GET_OP_LIST
 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
+
         >();
   }
 };
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 67f373d912dd4..3f094684aa9f8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -162,6 +162,11 @@ struct OneShotBufferizePass
     }
 
     BufferizationState state;
+
+    if (opt.cacheSymbolTables) {
+      state.addExtension<SymbolBufferizationState>();
+    }
+
     BufferizationStatistics statistics;
     ModuleOp moduleOp = getOperation();
     if (opt.bufferizeFunctionBoundaries) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 7c7c64f2aef01..663f5e420b953 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -116,6 +116,11 @@ class SparsificationAndBufferizationPass
 
     bufferization::BufferizationState bufferizationState;
 
+    if (updatedOptions.cacheSymbolTables) {
+      bufferizationState
+          .addExtension<bufferization::SymbolBufferizationState>();
+    }
+
     if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()),
                                                 updatedOptions,
                                                 bufferizationState)))

>From 019f5b96d1858965cd50bf6667b2d3e24196ec55 Mon Sep 17 00:00:00 2001
From: Michele Scuttari <michele.scuttari at outlook.com>
Date: Sun, 18 May 2025 18:43:02 +0200
Subject: [PATCH 6/7] Remove caching option and separate extension

---
 .../IR/BufferizableOpInterface.h              | 11 +++---
 .../Bufferization/Transforms/BufferUtils.h    |  5 ---
 .../Transforms/OneShotAnalysis.h              |  6 ----
 .../BufferizableOpInterfaceImpl.cpp           |  3 +-
 .../IR/BufferizableOpInterface.cpp            |  4 +++
 .../BufferizationTransformOps.cpp             |  4 ---
 .../Bufferization/Transforms/BufferUtils.cpp  | 35 ++++---------------
 .../Bufferization/Transforms/Bufferize.cpp    |  4 ---
 .../SparsificationAndBufferizationPass.cpp    |  5 ---
 9 files changed, 16 insertions(+), 61 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index e2c75b9b230fa..d644f49573a35 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -639,16 +639,13 @@ class BufferizationState {
     return const_cast<BufferizationState *>(this)->getExtension<Ty>();
   }
 
+  /// Get a reference to the collection of cached symbol tables.
+  SymbolTableCollection &getSymbolTables();
+
+private:
   /// Extensions attached to the state, identified by the TypeID of their type.
   /// Only one extension of any given type is allowed.
   DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
-};
-
-/// Extra bufferization state that is required for bufferization of operations
-/// declaring a symbol or a symbol table.
-struct SymbolBufferizationState : public BufferizationState::Extension {
-  SymbolBufferizationState(BufferizationState &state)
-      : BufferizationState::Extension(state) {}
 
   /// The cached symbol tables.
   /// The user is expected to update / invalidate the cached symbol tables if
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
index da0cbe31b0420..c08bd6c436133 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
@@ -127,11 +127,6 @@ FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
                                          uint64_t alignment,
                                          Attribute memorySpace = {});
 
-FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp op,
-                                         BufferizationState &state,
-                                         uint64_t alignment,
-                                         Attribute memorySpace);
-
 void removeSymbol(Operation *op, BufferizationState &state);
 
 void insertSymbol(Operation *op, BufferizationState &state);
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index fa6a08320bd60..15189d2c1cb87 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -52,12 +52,6 @@ struct OneShotBufferizationOptions : public BufferizationOptions {
   /// `AnalysisHeuristic::Fuzzer`. The fuzzer should be used only with
   /// `testAnalysisOnly = true`.
   unsigned analysisFuzzerSeed = 0;
-
-  /// Enable caching of symbol tables. If enabled, the SymbolBufferizationState
-  /// class is attached to the bufferization state and the user is required to
-  /// keep the cached symbol tables consistent with respect to the performed
-  /// bufferizations.
-  bool cacheSymbolTables = false;
 };
 
 /// State for analysis-enabled bufferization. This class keeps track of alias
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 1eabafaca261a..f646326ffc58f 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -47,7 +47,8 @@ struct ConstantOpInterface
     // Create global memory segment and replace tensor with memref pointing to
     // that memory segment.
     FailureOr<memref::GlobalOp> globalOp =
-        getGlobalFor(constantOp, state, options.bufferAlignment, memorySpace);
+        getGlobalFor(constantOp, state.getSymbolTables(),
+                     options.bufferAlignment, memorySpace);
     if (failed(globalOp))
       return failure();
     memref::GlobalOp globalMemref = *globalOp;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 0da720ad6da28..d6224b012ac95 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -127,6 +127,10 @@ void AnalysisState::resetCache() {
 
 BufferizationState::Extension::~Extension() = default;
 
+SymbolTableCollection &BufferizationState::getSymbolTables() {
+  return symbolTables;
+}
+
 Region *bufferization::getNextEnclosingRepetitiveRegion(
     Region *region, const BufferizationOptions &options) {
   assert(isRepetitiveRegion(region, options) && "expected repetitive region");
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index a6cae1f4dda33..db1eb20512033 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -85,10 +85,6 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
   auto payloadOps = state.getPayloadOps(getTarget());
   BufferizationState bufferizationState;
 
-  if (options.cacheSymbolTables) {
-    bufferizationState.addExtension<SymbolBufferizationState>();
-  }
-
   for (Operation *target : payloadOps) {
     if (!isa<ModuleOp, FunctionOpInterface>(target))
       return emitSilenceableError() << "expected module or function target";
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index a5aeb2d1ebb08..ff2c83d228dbb 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -161,40 +161,17 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp,
 }
 
 namespace mlir::bufferization {
-FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp op,
-                                         BufferizationState &state,
-                                         uint64_t alignment,
-                                         Attribute memorySpace) {
-  if (auto *symbolBufferizationState =
-          state.getExtension<SymbolBufferizationState>()) {
-    // Use the cached symbol tables.
-    return getGlobalFor(op, symbolBufferizationState->symbolTables, alignment,
-                        memorySpace);
-  }
-
-  SymbolTableCollection symbolTables;
-  return getGlobalFor(op, symbolTables, alignment, memorySpace);
-}
-
 void removeSymbol(Operation *op, BufferizationState &state) {
-  if (auto *symbolBufferizationState =
-          state.getExtension<SymbolBufferizationState>()) {
-    SymbolTable &symbolTable =
-        symbolBufferizationState->symbolTables.getSymbolTable(
-            op->getParentWithTrait<OpTrait::SymbolTable>());
+  SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
+      op->getParentWithTrait<OpTrait::SymbolTable>());
 
-    symbolTable.remove(op);
-  }
+  symbolTable.remove(op);
 }
 
 void insertSymbol(Operation *op, BufferizationState &state) {
-  if (auto *symbolBufferizationState =
-          state.getExtension<SymbolBufferizationState>()) {
-    SymbolTable &symbolTable =
-        symbolBufferizationState->symbolTables.getSymbolTable(
-            op->getParentWithTrait<OpTrait::SymbolTable>());
+  SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
+      op->getParentWithTrait<OpTrait::SymbolTable>());
 
-    symbolTable.insert(op);
-  }
+  symbolTable.insert(op);
 }
 } // namespace mlir::bufferization
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 3f094684aa9f8..38de525316f7a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -163,10 +163,6 @@ struct OneShotBufferizePass
 
     BufferizationState state;
 
-    if (opt.cacheSymbolTables) {
-      state.addExtension<SymbolBufferizationState>();
-    }
-
     BufferizationStatistics statistics;
     ModuleOp moduleOp = getOperation();
     if (opt.bufferizeFunctionBoundaries) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 663f5e420b953..7c7c64f2aef01 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -116,11 +116,6 @@ class SparsificationAndBufferizationPass
 
     bufferization::BufferizationState bufferizationState;
 
-    if (updatedOptions.cacheSymbolTables) {
-      bufferizationState
-          .addExtension<bufferization::SymbolBufferizationState>();
-    }
-
     if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()),
                                                 updatedOptions,
                                                 bufferizationState)))

>From 21006bc58e5befbd5b07286715b36abd3d2bac34 Mon Sep 17 00:00:00 2001
From: Michele Scuttari <michele.scuttari at outlook.com>
Date: Thu, 22 May 2025 07:49:07 +0200
Subject: [PATCH 7/7] Remove extension mechanism from BufferizationState

---
 .../IR/BufferizableOpInterface.h              | 61 -------------------
 .../IR/BufferizableOpInterface.cpp            |  2 -
 2 files changed, 63 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index d644f49573a35..43c97d57e1834 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -582,71 +582,10 @@ class AnalysisState {
 /// bufferization process.
 class BufferizationState {
 public:
-  /// Base class for BufferizationState extensions that allow BufferizationState
-  /// to contain user-specified information in the state object. The extension
-  /// mechanism of BufferizationState mirrors the one of OneShotAnalysisState.
-  class Extension {
-  public:
-    /// Base virtual destructor.
-    // Out-of-line definition ensures symbols are emitted in a single object
-    // file.
-    virtual ~Extension();
-
-  protected:
-    /// Constructs an extension of the given state object.
-    Extension(BufferizationState &state) : state(state) {}
-
-    /// Provides read-only access to the parent OneShotAnalysisState object.
-    const BufferizationState &getBufferizationState() const { return state; }
-
-  private:
-    /// Back-reference to the state that is being extended.
-    BufferizationState &state;
-  };
-
-  /// Adds a new Extension of the type specified as template parameter,
-  /// constructing it with the arguments provided. The extension is owned by the
-  /// BufferizationState. It is expected that the state does not already have an
-  /// extension of the same type. Extension constructors are expected to take a
-  /// reference to BufferizationState as first argument, automatically supplied
-  /// by this call.
-  template <typename Ty, typename... Args>
-  Ty &addExtension(Args &&...args) {
-    static_assert(std::is_base_of<Extension, Ty>::value,
-                  "only a class derived from "
-                  "BufferizationState::Extension is allowed");
-    auto ptr = std::make_unique<Ty>(*this, std::forward<Args>(args)...);
-    auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
-    assert(result.second && "extension already added");
-    return *static_cast<Ty *>(result.first->second.get());
-  }
-
-  /// Returns the extension of the specified type.
-  template <typename Ty>
-  Ty *getExtension() {
-    static_assert(std::is_base_of<Extension, Ty>::value,
-                  "only a class derived from "
-                  "BufferizationState::Extension is allowed");
-    auto iter = extensions.find(TypeID::get<Ty>());
-    if (iter == extensions.end())
-      return nullptr;
-    return static_cast<Ty *>(iter->second.get());
-  }
-
-  /// Returns the extension of the specified type.
-  template <typename Ty>
-  const Ty *getExtension() const {
-    return const_cast<BufferizationState *>(this)->getExtension<Ty>();
-  }
-
   /// Get a reference to the collection of cached symbol tables.
   SymbolTableCollection &getSymbolTables();
 
 private:
-  /// Extensions attached to the state, identified by the TypeID of their type.
-  /// Only one extension of any given type is allowed.
-  DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
-
   /// The cached symbol tables.
   /// The user is expected to update / invalidate the cached symbol tables if
   /// the bufferized operation has the Symbol or SymbolTable traits.
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index d6224b012ac95..14fa4c1ed8159 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -125,8 +125,6 @@ void AnalysisState::resetCache() {
   insideMutuallyExclusiveRegionsCache.clear();
 }
 
-BufferizationState::Extension::~Extension() = default;
-
 SymbolTableCollection &BufferizationState::getSymbolTables() {
   return symbolTables;
 }



More information about the Mlir-commits mailing list