[Mlir-commits] [mlir] ec8628b - [mlir][linalg][bufferize][NFC] Pass BufferizationState into all op interface methods

Matthias Springer llvmlistbot at llvm.org
Wed Dec 15 18:50:14 PST 2021


Author: Matthias Springer
Date: 2021-12-16T11:45:13+09:00
New Revision: ec8628b1d615270e0e86a4efb71c9477dd95b195

URL: https://github.com/llvm/llvm-project/commit/ec8628b1d615270e0e86a4efb71c9477dd95b195
DIFF: https://github.com/llvm/llvm-project/commit/ec8628b1d615270e0e86a4efb71c9477dd95b195.diff

LOG: [mlir][linalg][bufferize][NFC] Pass BufferizationState into all op interface methods

This allows op interface implementations to make decisions based on dialect-specific bufferization state.

This is in preparation for fixing conflict detection of CallOps in ModuleBufferization.

Differential Revision: https://reviews.llvm.org/D115705

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 35955af49efa..891d59b61616 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -179,8 +179,7 @@ enum class BufferRelation {
 /// equivalence classes to support bufferization.
 class BufferizationAliasInfo {
 public:
-  explicit BufferizationAliasInfo(Operation *rootOp,
-                                  const BufferizationOptions &options);
+  explicit BufferizationAliasInfo(Operation *rootOp);
 
   // BufferizationAliasInfo should be passed as a reference.
   BufferizationAliasInfo(const BufferizationAliasInfo &) = delete;
@@ -271,68 +270,6 @@ class BufferizationAliasInfo {
 /// Return `true` if the given value is a BlockArgument of a FuncOp.
 bool isFunctionArgument(Value value);
 
-/// Determine which OpOperand* will alias with `result` if the op is bufferized
-/// in place. Return an empty vector if the op is not bufferizable.
-SmallVector<OpOperand *> getAliasingOpOperand(OpResult result);
-
-/// Determine which OpResult will alias with `opOperand` if the op is bufferized
-/// in place. Return an empty OpResult if the op is not bufferizable.
-OpResult getAliasingOpResult(OpOperand &opOperand);
-
-/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
-/// op is not bufferizable.
-bool bufferizesToMemoryRead(OpOperand &opOperand);
-
-/// Return true if `opOperand` bufferizes to a memory write. Return
-/// `true` if the op is not bufferizable.
-bool bufferizesToMemoryWrite(OpOperand &opOperand);
-
-/// Return true if `opOperand` does neither read nor write but bufferizes to an
-/// alias. Return false if the op is not bufferizable.
-bool bufferizesToAliasOnly(OpOperand &opOperand);
-
-/// Return true if the given value is read by an op that bufferizes to a memory
-/// read. Also takes into account ops that create an alias but do not read by
-/// themselves (e.g., ExtractSliceOp).
-bool isValueRead(Value value);
-
-/// Starting from `value`, follow the use-def chain in reverse, always selecting
-/// the aliasing OpOperands. Find and return Values for which `condition`
-/// evaluates to true. OpOperands of such matching Values are not traversed any
-/// further.
-///
-/// When reaching the end of a chain (BlockArgument or Value without aliasing
-/// OpOperands), also return the last Value of that chain.
-///
-/// Example:
-///
-///                               8
-///                               |
-///   6*         7*         +-----+----+
-///   |          |          |          |
-///   2*         3          4*         5
-///   |          |          |          |
-///   +----------+----------+----------+
-///              |
-///              1
-///
-/// In the above example, Values with a star satisfy the condition. When
-/// starting the traversal from Value 1, the resulting SetVector is:
-/// { 2, 7, 8, 5 }
-llvm::SetVector<Value>
-findValueInReverseUseDefChain(Value value, const BufferizationOptions &options,
-                              std::function<bool(Value)> condition);
-
-/// Find the Value of the last preceding write of a given Value.
-///
-/// Note: Unknown ops are handled conservatively and assumed to be writes.
-/// Furthermore, BlockArguments are also assumed to be writes. There is no
-/// analysis across block boundaries.
-///
-/// Note: When reaching an end of the reverse SSA use-def chain, that value
-/// is returned regardless of whether it is a memory write or not.
-Value findLastPrecedingWrite(Value value, const BufferizationOptions &options);
-
 /// Dialect-specific bufferization state. Analysis/bufferization information
 /// that is specific to ops from a certain dialect can be stored in derived
 /// variants of this struct.
@@ -359,12 +296,74 @@ struct DialectBufferizationState {
 /// * `replaceOp` replaces an op with new values.
 class BufferizationState {
 public:
-  BufferizationState(Operation *op, const BufferizationOptions &options)
-      : aliasInfo(op, options), options(options), builder(op->getContext()) {}
+  BufferizationState(Operation *op, const BufferizationOptions &options);
 
   // BufferizationState should be passed as a reference.
   BufferizationState(const BufferizationState &) = delete;
 
+  /// Determine which OpOperand* will alias with `result` if the op is
+  /// bufferized in place. Return an empty vector if the op is not bufferizable.
+  SmallVector<OpOperand *> getAliasingOpOperand(OpResult result);
+
+  /// Determine which OpResult will alias with `opOperand` if the op is
+  /// bufferized in place. Return an empty OpResult if the op is not
+  /// bufferizable.
+  OpResult getAliasingOpResult(OpOperand &opOperand);
+
+  /// Return true if `opOperand` bufferizes to a memory read. Return `true` if
+  /// the op is not bufferizable.
+  bool bufferizesToMemoryRead(OpOperand &opOperand);
+
+  /// Return true if `opOperand` bufferizes to a memory write. Return `true` if
+  /// the op is not bufferizable.
+  bool bufferizesToMemoryWrite(OpOperand &opOperand);
+
+  /// Return true if `opOperand` neither reads nor writes but bufferizes to
+  /// an alias. Return false if the op is not bufferizable.
+  bool bufferizesToAliasOnly(OpOperand &opOperand);
+
+  /// Return true if the given value is read by an op that bufferizes to a
+  /// memory read. Also takes into account ops that create an alias but do not
+  /// read by themselves (e.g., ExtractSliceOp).
+  bool isValueRead(Value value);
+
+  /// Starting from `value`, follow the use-def chain in reverse, always
+  /// selecting the aliasing OpOperands. Find and return Values for which
+  /// `condition` evaluates to true. OpOperands of such matching Values are not
+  /// traversed any further.
+  ///
+  /// When reaching the end of a chain (BlockArgument or Value without aliasing
+  /// OpOperands), also return the last Value of that chain.
+  ///
+  /// Example:
+  ///
+  ///                               8
+  ///                               |
+  ///   6*         7*         +-----+----+
+  ///   |          |          |          |
+  ///   2*         3          4*         5
+  ///   |          |          |          |
+  ///   +----------+----------+----------+
+  ///              |
+  ///              1
+  ///
+  /// In the above example, Values with a star satisfy the condition. When
+  /// starting the traversal from Value 1, the resulting SetVector is:
+  /// { 2, 7, 8, 5 }
+  llvm::SetVector<Value>
+  findValueInReverseUseDefChain(Value value,
+                                std::function<bool(Value)> condition);
+
+  /// Find the Value of the last preceding write of a given Value.
+  ///
+  /// Note: Unknown ops are handled conservatively and assumed to be writes.
+  /// Furthermore, BlockArguments are also assumed to be writes. There is no
+  /// analysis across block boundaries.
+  ///
+  /// Note: When reaching an end of the reverse SSA use-def chain, that value
+  /// is returned regardless of whether it is a memory write or not.
+  Value findLastPrecedingWrite(Value value);
+
   /// Creates a memref allocation.
   Optional<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
                               ArrayRef<Value> dynShape);
@@ -494,25 +493,30 @@ template <typename OpTy>
 struct AllocationHoistingBarrierOnly
     : public BufferizableOpInterface::ExternalModel<
           AllocationHoistingBarrierOnly<OpTy>, OpTy> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              BufferizationState &state) const {
     return true;
   }
 
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     return false;
   }
 
-  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
-                                                OpResult opResult) const {
+  SmallVector<OpOperand *>
+  getAliasingOpOperand(Operation *op, OpResult opResult,
+                       BufferizationState &state) const {
     return {};
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     return OpResult();
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo) const {
+                                const BufferizationAliasInfo &aliasInfo,
+                                BufferizationState &state) const {
     return BufferRelation::None;
   }
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
index a81b52d1433f..df9090972bed 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
@@ -32,7 +32,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         }],
         /*retType=*/"bool",
         /*methodName=*/"bufferizesToMemoryRead",
-        /*args=*/(ins "OpOperand &":$opOperand),
+        /*args=*/(ins "OpOperand &":$opOperand,
+                      "BufferizationState &":$state),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           // Does not have to be implemented for ops without tensor OpOperands.
@@ -60,7 +61,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         }],
         /*retType=*/"bool",
         /*methodName=*/"bufferizesToMemoryWrite",
-        /*args=*/(ins "OpOperand &":$opOperand),
+        /*args=*/(ins "OpOperand &":$opOperand,
+                      "BufferizationState &":$state),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           // Does not have to be implemented for ops without tensor OpOperands.
@@ -82,19 +84,21 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           }],
           /*retType=*/"bool",
           /*methodName=*/"isMemoryWrite",
-          /*args=*/(ins "OpResult":$opResult),
+          /*args=*/(ins "OpResult":$opResult,
+                        "BufferizationState &":$state),
           /*methodBody=*/"",
           /*defaultImplementation=*/[{
             auto bufferizableOp =
                 cast<BufferizableOpInterface>($_op.getOperation());
             SmallVector<OpOperand*> opOperands =
-              bufferizableOp.getAliasingOpOperand(opResult);
+              bufferizableOp.getAliasingOpOperand(opResult, state);
             if (opOperands.empty())
               return true;
             return llvm::any_of(
                 opOperands,
                 [&](OpOperand *operand) {
-                  return bufferizableOp.bufferizesToMemoryWrite(*operand);
+                  return bufferizableOp.bufferizesToMemoryWrite(*operand,
+                                                                state);
                 });
           }]
       >,
@@ -111,7 +115,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         }],
         /*retType=*/"bool",
         /*methodName=*/"mustBufferizeInPlace",
-        /*args=*/(ins "OpResult":$opResult),
+        /*args=*/(ins "OpResult":$opResult,
+                      "BufferizationState &":$state),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return false;
@@ -125,7 +130,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         }],
         /*retType=*/"OpResult",
         /*methodName=*/"getAliasingOpResult",
-        /*args=*/(ins "OpOperand &":$opOperand),
+        /*args=*/(ins "OpOperand &":$opOperand,
+                      "BufferizationState &":$state),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           // Does not have to be implemented for ops without tensor OpOperands.
@@ -148,7 +154,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         }],
         /*retType=*/"SmallVector<OpOperand *>",
         /*methodName=*/"getAliasingOpOperand",
-        /*args=*/(ins "OpResult":$opResult),
+        /*args=*/(ins "OpResult":$opResult,
+                      "BufferizationState &":$state),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           assert(opResult.getType().isa<TensorType>() &&
@@ -159,7 +166,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           for (OpOperand &opOperand : $_op.getOperation()->getOpOperands()) {
             if (!opOperand.get().getType().isa<TensorType>())
               continue;
-            if (bufferizableOp.getAliasingOpResult(opOperand) == opResult)
+            if (bufferizableOp.getAliasingOpResult(opOperand, state) ==
+                    opResult)
               result.push_back(&opOperand);
           }
           return result;
@@ -179,7 +187,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*retType=*/"BufferRelation",
         /*methodName=*/"bufferRelation",
         /*args=*/(ins "OpResult":$opResult,
-                      "const BufferizationAliasInfo &":$aliasInfo),
+                      "const BufferizationAliasInfo &":$aliasInfo,
+                      "BufferizationState &":$state),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           // Does not have to be implemented for ops without tensor OpResults
@@ -282,13 +291,14 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
     /// be called on OpOperands that do not have a tensor type.
     ///
     /// Examples of such ops are `tensor.extract_slice` and `tensor.cast`.
-    bool bufferizesToAliasOnly(OpOperand &opOperand) {
+    bool bufferizesToAliasOnly(OpOperand &opOperand,
+                               BufferizationState &state) {
       auto bufferizableOp =
           cast<BufferizableOpInterface>(getOperation());
-      return !bufferizableOp.bufferizesToMemoryRead(opOperand)
-          && !bufferizableOp.bufferizesToMemoryWrite(opOperand)
+      return !bufferizableOp.bufferizesToMemoryRead(opOperand, state)
+          && !bufferizableOp.bufferizesToMemoryWrite(opOperand, state)
           && static_cast<bool>(
-              bufferizableOp.getAliasingOpResult(opOperand));
+              bufferizableOp.getAliasingOpResult(opOperand, state));
     }
 
     // TODO: The following two attributes should belong to the tensor dialect.

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index d6d3d28a1022..e2edc9d15267 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -78,8 +78,7 @@ BufferizationOptions::BufferizationOptions()
 // BufferizationAliasInfo
 //===----------------------------------------------------------------------===//
 
-BufferizationAliasInfo::BufferizationAliasInfo(
-    Operation *rootOp, const BufferizationOptions &options) {
+BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
   rootOp->walk([&](Operation *op) {
     for (Value v : op->getResults())
       if (v.getType().isa<TensorType>())
@@ -90,26 +89,6 @@ BufferizationAliasInfo::BufferizationAliasInfo(
           if (bbArg.getType().isa<TensorType>())
             createAliasInfoEntry(bbArg);
   });
-
-  // Set up alias sets for OpResults that must bufferize in-place. This should
-  // be done before making any other bufferization decisions.
-  rootOp->walk([&](BufferizableOpInterface bufferizableOp) {
-    if (!options.isOpAllowed(bufferizableOp))
-      return WalkResult::skip();
-    for (OpResult opResult : bufferizableOp->getOpResults()) {
-      if (opResult.getType().isa<TensorType>())
-        if (bufferizableOp.mustBufferizeInPlace(opResult)) {
-          SmallVector<OpOperand *> operands =
-              bufferizableOp.getAliasingOpOperand(opResult);
-          assert(!operands.empty() &&
-                 "expected that OpResult has aliasing OpOperand");
-          for (OpOperand *operand : operands)
-            aliasInfo.unionSets(operand->get(), opResult);
-          markInPlace(opResult);
-        }
-    }
-    return WalkResult::advance();
-  });
 }
 
 /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
@@ -219,30 +198,32 @@ BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
 /// Determine which OpOperand* will alias with `result` if the op is bufferized
 /// in place. Return an empty vector if the op is not bufferizable.
 SmallVector<OpOperand *>
-mlir::linalg::comprehensive_bufferize::getAliasingOpOperand(OpResult result) {
+mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpOperand(
+    OpResult result) {
   if (Operation *op = result.getDefiningOp())
     if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
-      return bufferizableOp.getAliasingOpOperand(result);
+      return bufferizableOp.getAliasingOpOperand(result, *this);
   return {};
 }
 
 /// Determine which OpResult will alias with `opOperand` if the op is bufferized
 /// in place. Return an empty OpResult if the op is not bufferizable.
-OpResult mlir::linalg::comprehensive_bufferize::getAliasingOpResult(
+OpResult
+mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpResult(
     OpOperand &opOperand) {
   if (auto bufferizableOp =
           dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
-    return bufferizableOp.getAliasingOpResult(opOperand);
+    return bufferizableOp.getAliasingOpResult(opOperand, *this);
   return OpResult();
 }
 
 /// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
 /// op is not bufferizable.
-bool mlir::linalg::comprehensive_bufferize::bufferizesToMemoryRead(
-    OpOperand &opOperand) {
+bool mlir::linalg::comprehensive_bufferize::BufferizationState::
+    bufferizesToMemoryRead(OpOperand &opOperand) {
   if (auto bufferizableOp =
           dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
-    return bufferizableOp.bufferizesToMemoryRead(opOperand);
+    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);
 
   // Unknown op that returns a tensor. The inplace analysis does not support it.
   // Conservatively return true.
@@ -251,11 +232,11 @@ bool mlir::linalg::comprehensive_bufferize::bufferizesToMemoryRead(
 
 /// Return true if `opOperand` bufferizes to a memory write. Return
 /// `true` if the op is not bufferizable.
-bool mlir::linalg::comprehensive_bufferize::bufferizesToMemoryWrite(
-    OpOperand &opOperand) {
+bool mlir::linalg::comprehensive_bufferize::BufferizationState::
+    bufferizesToMemoryWrite(OpOperand &opOperand) {
   if (auto bufferizableOp =
           dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
-    return bufferizableOp.bufferizesToMemoryWrite(opOperand);
+    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);
 
   // Unknown op that returns a tensor. The inplace analysis does not support it.
   // Conservatively return true.
@@ -264,11 +245,11 @@ bool mlir::linalg::comprehensive_bufferize::bufferizesToMemoryWrite(
 
 /// Return true if `opOperand` does neither read nor write but bufferizes to an
 /// alias. Return false if the op is not bufferizable.
-bool mlir::linalg::comprehensive_bufferize::bufferizesToAliasOnly(
-    OpOperand &opOperand) {
+bool mlir::linalg::comprehensive_bufferize::BufferizationState::
+    bufferizesToAliasOnly(OpOperand &opOperand) {
   if (auto bufferizableOp =
           dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
-    return bufferizableOp.bufferizesToAliasOnly(opOperand);
+    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);
 
   // Unknown op that returns a tensor. The inplace analysis does not support it.
   // Conservatively return false.
@@ -278,7 +259,8 @@ bool mlir::linalg::comprehensive_bufferize::bufferizesToAliasOnly(
 /// Return true if the given value is read by an op that bufferizes to a memory
 /// read. Also takes into account ops that create an alias but do not read by
 /// themselves (e.g., ExtractSliceOp).
-bool mlir::linalg::comprehensive_bufferize::isValueRead(Value value) {
+bool mlir::linalg::comprehensive_bufferize::BufferizationState::isValueRead(
+    Value value) {
   SmallVector<OpOperand *> workingSet;
   for (OpOperand &use : value.getUses())
     workingSet.push_back(&use);
@@ -301,9 +283,9 @@ bool mlir::linalg::comprehensive_bufferize::isValueRead(Value value) {
 // evaluates to true. OpOperands of such matching Values are not traversed any
 // further.
 llvm::SetVector<Value>
-mlir::linalg::comprehensive_bufferize::findValueInReverseUseDefChain(
-    Value value, const BufferizationOptions &options,
-    std::function<bool(Value)> condition) {
+mlir::linalg::comprehensive_bufferize::BufferizationState::
+    findValueInReverseUseDefChain(Value value,
+                                  std::function<bool(Value)> condition) {
   llvm::SetVector<Value> result, workingSet;
   workingSet.insert(value);
 
@@ -329,17 +311,17 @@ mlir::linalg::comprehensive_bufferize::findValueInReverseUseDefChain(
 }
 
 // Find the Value of the last preceding write of a given Value.
-Value mlir::linalg::comprehensive_bufferize::findLastPrecedingWrite(
-    Value value, const BufferizationOptions &options) {
+Value mlir::linalg::comprehensive_bufferize::BufferizationState::
+    findLastPrecedingWrite(Value value) {
   SetVector<Value> result =
-      findValueInReverseUseDefChain(value, options, [&](Value value) {
+      findValueInReverseUseDefChain(value, [&](Value value) {
         Operation *op = value.getDefiningOp();
         if (!op)
           return true;
         auto bufferizableOp = options.dynCastBufferizableOp(op);
         if (!bufferizableOp)
           return true;
-        return bufferizableOp.isMemoryWrite(value.cast<OpResult>());
+        return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
       });
 
   // To simplify the analysis, `scf.if` ops are considered memory writes. There
@@ -350,6 +332,30 @@ Value mlir::linalg::comprehensive_bufferize::findLastPrecedingWrite(
   return result.front();
 }
 
+mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
+    Operation *op, const BufferizationOptions &options)
+    : aliasInfo(op), options(options), builder(op->getContext()) {
+  // Set up alias sets for OpResults that must bufferize in-place. This should
+  // be done before making any other bufferization decisions.
+  op->walk([&](BufferizableOpInterface bufferizableOp) {
+    if (!options.isOpAllowed(bufferizableOp))
+      return WalkResult::skip();
+    for (OpResult opResult : bufferizableOp->getOpResults()) {
+      if (opResult.getType().isa<TensorType>())
+        if (bufferizableOp.mustBufferizeInPlace(opResult, *this)) {
+          SmallVector<OpOperand *> operands =
+              bufferizableOp.getAliasingOpOperand(opResult, *this);
+          assert(!operands.empty() &&
+                 "expected that OpResult has aliasing OpOperand");
+          for (OpOperand *operand : operands)
+            aliasInfo.unionAliasSets(operand->get(), opResult);
+          aliasInfo.markInPlace(opResult);
+        }
+    }
+    return WalkResult::advance();
+  });
+}
+
 /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
 /// a new buffer and copy over data from the existing buffer if out-of-place
 /// bufferization is necessary.
@@ -394,9 +400,9 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
     // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
     // use-def chain, it returns that value, regardless of whether it is a
     // memory write or not.
-    Value lastWrite = findLastPrecedingWrite(operand, options);
+    Value lastWrite = findLastPrecedingWrite(operand);
     if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
-      if (!bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>()))
+      if (!bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(), *this))
         skipCopy = true;
     // Do not copy if the copied data is never read.
     if (!isValueRead(result))

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
index 97345350835a..3419a6aa4492 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
@@ -39,12 +39,14 @@ namespace bufferization_ext {
 struct ToMemrefOpInterface
     : public BufferizableOpInterface::ExternalModel<ToMemrefOpInterface,
                                                     bufferization::ToMemrefOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              BufferizationState &state) const {
     // It is unknown whether the resulting MemRef will be read or not.
     return true;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     return OpResult();
   }
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index b9dde90e63ee..babbec5493ae 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -162,7 +162,8 @@ static void setInPlaceOpResult(OpResult opResult, bool inPlace) {
 
 /// Return true if opOperand has been decided to bufferize in-place.
 static bool isInplaceMemoryWrite(OpOperand &opOperand,
-                                 const BufferizationAliasInfo &aliasInfo) {
+                                 const BufferizationAliasInfo &aliasInfo,
+                                 BufferizationState &state) {
   // The analysis does not know what happens to the result of a ToMemrefOp, so
   // we assume that it is written to.
   // TODO: This is a conservative implementation. This rule will have to be
@@ -170,11 +171,11 @@ static bool isInplaceMemoryWrite(OpOperand &opOperand,
   if (isa<bufferization::ToMemrefOp>(opOperand.getOwner()))
     return true;
   // OpOperands without an aliasing OpResult do not write.
-  OpResult opResult = getAliasingOpResult(opOperand);
+  OpResult opResult = state.getAliasingOpResult(opOperand);
   if (!opResult)
     return false;
   // OpOperands that do not bufferize to a memory write do not write in-place.
-  if (!bufferizesToMemoryWrite(opOperand))
+  if (!state.bufferizesToMemoryWrite(opOperand))
     return false;
   // Check current bufferization decisions.
   return aliasInfo.isInPlace(opResult);
@@ -209,11 +210,12 @@ static bool aliasesNonWritableBuffer(Value value,
 /// Return true if the buffer to which `operand` would bufferize is equivalent
 /// to some buffer write.
 static bool aliasesInPlaceWrite(Value value,
-                                const BufferizationAliasInfo &aliasInfo) {
+                                const BufferizationAliasInfo &aliasInfo,
+                                BufferizationState &state) {
   bool foundInplaceWrite = false;
   aliasInfo.applyOnAliases(value, [&](Value v) {
     for (auto &use : v.getUses()) {
-      if (isInplaceMemoryWrite(use, aliasInfo)) {
+      if (isInplaceMemoryWrite(use, aliasInfo, state)) {
         foundInplaceWrite = true;
         return;
       }
@@ -295,7 +297,7 @@ static bool hasReadAfterWriteInterference(
     // In the above example, if uRead is the OpOperand of reading_op, lastWrite
     // is %0. Note that operations that create an alias but do not write (such
     // as ExtractSliceOp) are skipped.
-    Value lastWrite = findLastPrecedingWrite(uRead->get(), options);
+    Value lastWrite = state.findLastPrecedingWrite(uRead->get());
 
     // Look for conflicting memory writes. Potential conflicts are writes to an
     // alias that have been decided to bufferize inplace.
@@ -352,7 +354,7 @@ static bool hasReadAfterWriteInterference(
 
       // No conflict if the conflicting write and the last write are the same
       // use.
-      if (getAliasingOpResult(*uConflictingWrite) == lastWrite)
+      if (state.getAliasingOpResult(*uConflictingWrite) == lastWrite)
         continue;
 
       // All requirements are met. Conflict found!
@@ -402,7 +404,7 @@ bool wouldCreateReadAfterWriteInterference(
     bool checkConsistencyOnly = false) {
 #ifndef NDEBUG
   if (result) {
-    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
+    SmallVector<OpOperand *> opOperands = state.getAliasingOpOperand(result);
     assert(llvm::find(opOperands, &operand) != opOperands.end() &&
            "operand and result do not match");
   } else {
@@ -416,7 +418,7 @@ bool wouldCreateReadAfterWriteInterference(
     aliasInfo.applyOnAliases(root, [&](Value alias) {
       for (auto &use : alias.getUses())
         // Read to a value that aliases root.
-        if (bufferizesToMemoryRead(use))
+        if (state.bufferizesToMemoryRead(use))
           res.insert(&use);
     });
   };
@@ -426,7 +428,7 @@ bool wouldCreateReadAfterWriteInterference(
     aliasInfo.applyOnAliases(root, [&](Value alias) {
       for (auto &use : alias.getUses())
         // Inplace write to a value that aliases root.
-        if (isInplaceMemoryWrite(use, aliasInfo))
+        if (isInplaceMemoryWrite(use, aliasInfo, state))
           res.insert(&use);
     });
   };
@@ -439,7 +441,7 @@ bool wouldCreateReadAfterWriteInterference(
   getAliasingInplaceWrites(usesWrite, operand.get());
   if (result)
     getAliasingInplaceWrites(usesWrite, result);
-  if (!checkConsistencyOnly && bufferizesToMemoryWrite(operand))
+  if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
     usesWrite.insert(&operand);
 
   return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state,
@@ -453,7 +455,7 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, OpResult opResult,
                                     const BufferizationAliasInfo &aliasInfo,
                                     BufferizationState &state) {
 #ifndef NDEBUG
-  SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
+  SmallVector<OpOperand *> opOperands = state.getAliasingOpOperand(opResult);
   assert(llvm::find(opOperands, &opOperand) != opOperands.end() &&
          "operand and result do not match");
 #endif // NDEBUG
@@ -467,9 +469,9 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, OpResult opResult,
     return false;
 
   // This is a problem only if the buffer is written to via some alias.
-  bool hasWrite = aliasesInPlaceWrite(opResult, aliasInfo) ||
-                  aliasesInPlaceWrite(opOperand.get(), aliasInfo) ||
-                  bufferizesToMemoryWrite(opOperand);
+  bool hasWrite = aliasesInPlaceWrite(opResult, aliasInfo, state) ||
+                  aliasesInPlaceWrite(opOperand.get(), aliasInfo, state) ||
+                  state.bufferizesToMemoryWrite(opOperand);
   if (!hasWrite)
     return false;
 
@@ -485,7 +487,7 @@ static LogicalResult bufferizableInPlaceAnalysisImpl(
     OpOperand &operand, OpResult result, BufferizationAliasInfo &aliasInfo,
     BufferizationState &state, const DominanceInfo &domInfo) {
 #ifndef NDEBUG
-  SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
+  SmallVector<OpOperand *> opOperands = state.getAliasingOpOperand(result);
   assert(llvm::find(opOperands, &operand) != opOperands.end() &&
          "operand and result do not match");
 #endif // NDEBUG
@@ -539,7 +541,8 @@ static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
     for (OpOperand &opOperand : op->getOpOperands())
       if (opOperand.get().getType().isa<TensorType>())
         if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
-          if (OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand))
+          if (OpResult opResult =
+                  bufferizableOp.getAliasingOpResult(opOperand, state))
             if (failed(bufferizableInPlaceAnalysisImpl(
                     opOperand, opResult, aliasInfo, state, domInfo)))
               return failure();
@@ -569,16 +572,16 @@ static LogicalResult inPlaceAnalysis(Operation *op,
 /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
 static void equivalenceAnalysis(SmallVector<Operation *> &ops,
                                 BufferizationAliasInfo &aliasInfo,
-                                const BufferizationOptions &options) {
+                                BufferizationState &state) {
   for (Operation *op : ops)
-    if (auto bufferizableOp = options.dynCastBufferizableOp(op))
+    if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
       for (OpResult opResult : op->getOpResults())
         if (opResult.getType().isa<TensorType>())
           if (aliasInfo.isInPlace(opResult)) {
             SmallVector<OpOperand *> opOperands =
-                bufferizableOp.getAliasingOpOperand(opResult);
+                bufferizableOp.getAliasingOpOperand(opResult, state);
             if (!opOperands.empty())
-              if (bufferizableOp.bufferRelation(opResult, aliasInfo) ==
+              if (bufferizableOp.bufferRelation(opResult, aliasInfo, state) ==
                   BufferRelation::Equivalent)
                 for (OpOperand *opOperand : opOperands)
                   aliasInfo.unionEquivalenceClasses(opResult, opOperand->get());
@@ -589,7 +592,7 @@ static void equivalenceAnalysis(SmallVector<Operation *> &ops,
 /// in `op`.
 static void equivalenceAnalysis(Operation *op,
                                 BufferizationAliasInfo &aliasInfo,
-                                const BufferizationOptions &options) {
+                                BufferizationState &state) {
   // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
   SmallVector<Operation *> ops;
   op->walk<WalkOrder::PostOrder>([&](Operation *op) {
@@ -599,7 +602,7 @@ static void equivalenceAnalysis(Operation *op,
     ops.push_back(op);
   });
 
-  equivalenceAnalysis(ops, aliasInfo, options);
+  equivalenceAnalysis(ops, aliasInfo, state);
 }
 
 /// Assert that the current bufferization decisions are consistent.
@@ -613,7 +616,8 @@ checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
     if (auto bufferizableOp = options.dynCastBufferizableOp(op))
       for (OpOperand &opOperand : op->getOpOperands())
         if (opOperand.get().getType().isa<TensorType>()) {
-          OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand);
+          OpResult opResult =
+              bufferizableOp.getAliasingOpResult(opOperand, state);
           if (wouldCreateReadAfterWriteInterference(
                   opOperand, opResult, domInfo, state, aliasInfo,
                   /*checkConsistencyOnly=*/true)) {
@@ -669,7 +673,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
   if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo,
                              options.analysisFuzzerSeed)))
     return failure();
-  equivalenceAnalysis(op, aliasInfo, options);
+  equivalenceAnalysis(op, aliasInfo, state);
 
   auto runPostAnalysisSteps = [&](const PostAnalysisStepList &steps) {
     for (const std::unique_ptr<PostAnalysisStep> &step : steps) {
@@ -679,7 +683,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
       // Analyze ops that were created by the PostAnalysisStep.
       if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo)))
         return failure();
-      equivalenceAnalysis(newOps, aliasInfo, options);
+      equivalenceAnalysis(newOps, aliasInfo, state);
     }
     return success();
   };

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 9984ae1ad122..158ad6a76343 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -140,18 +140,22 @@ template <typename OpTy>
 struct LinalgOpInterface
     : public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
                                                     OpTy> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              BufferizationState &state) const {
     auto genericOp = cast<linalg::LinalgOp>(op);
     return genericOp.payloadUsesValueFromOperand(&opOperand);
   }
 
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     auto bufferizableOp = cast<BufferizableOpInterface>(op);
-    return static_cast<bool>(bufferizableOp.getAliasingOpResult(opOperand));
+    return static_cast<bool>(
+        bufferizableOp.getAliasingOpResult(opOperand, state));
   }
 
-  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
-                                                OpResult opResult) const {
+  SmallVector<OpOperand *>
+  getAliasingOpOperand(Operation *op, OpResult opResult,
+                       BufferizationState &state) const {
     auto genericOp = cast<linalg::LinalgOp>(op);
     DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands())
@@ -160,14 +164,16 @@ struct LinalgOpInterface
     return {};
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     auto genericOp = cast<linalg::LinalgOp>(op);
     DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
     return pairs[&opOperand];
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo) const {
+                                const BufferizationAliasInfo &aliasInfo,
+                                BufferizationState &state) const {
     return BufferRelation::Equivalent;
   }
 
@@ -180,7 +186,8 @@ struct LinalgOpInterface
 struct InitTensorOpInterface
     : public BufferizableOpInterface::ExternalModel<InitTensorOpInterface,
                                                     linalg::InitTensorOp> {
-  bool isMemoryWrite(Operation *op, OpResult opResult) const {
+  bool isMemoryWrite(Operation *op, OpResult opResult,
+                     BufferizationState &state) const {
     // InitTensorOps allocate but do not write.
     return false;
   }
@@ -203,27 +210,32 @@ struct InitTensorOpInterface
 struct TiledLoopOpInterface
     : public BufferizableOpInterface::ExternalModel<TiledLoopOpInterface,
                                                     linalg::TiledLoopOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              BufferizationState &state) const {
     // TiledLoop alone doesn't bufferize to a memory read, one of the uses of
     // its matching bbArg may.
     auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
-    return isValueRead(tiledLoopOp.getTiedBlockArgument(opOperand));
+    return state.isValueRead(tiledLoopOp.getTiedBlockArgument(opOperand));
   }
 
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     // TiledLoop alone doesn't bufferize to a memory write, one of the uses of
     // its matching bbArg may.
     auto bufferizableOp = cast<BufferizableOpInterface>(op);
-    return static_cast<bool>(bufferizableOp.getAliasingOpResult(opOperand));
+    return static_cast<bool>(
+        bufferizableOp.getAliasingOpResult(opOperand, state));
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
     return tiledLoopOp.getTiedOpResult(opOperand);
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo) const {
+                                const BufferizationAliasInfo &aliasInfo,
+                                BufferizationState &state) const {
     return BufferRelation::Equivalent;
   }
 
@@ -331,15 +343,18 @@ struct TiledLoopOpInterface
 struct YieldOpInterface
     : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                     linalg::YieldOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              BufferizationState &state) const {
     return true;
   }
 
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     return false;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     return OpResult();
   }
 
@@ -391,7 +406,6 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
         std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
         SmallVector<Operation *> &newOps) {
   OpBuilder b(op->getContext());
-  const BufferizationOptions &options = state.getOptions();
 
   WalkResult status = op->walk([&](Operation *op) {
     for (OpOperand &operand : op->getOpOperands()) {
@@ -400,7 +414,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
         continue;
 
       SetVector<Value> maybeInitTensor =
-          findValueInReverseUseDefChain(operand.get(), options, [&](Value val) {
+          state.findValueInReverseUseDefChain(operand.get(), [&](Value val) {
             // Continue traversal until this function returns true.
             OpResult opResult = val.dyn_cast<OpResult>();
             if (!opResult)
@@ -410,7 +424,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
             // Only equivalent tensors are supported at the moment.
             // TODO: Support cases such as extract_slice(init_tensor).
             SmallVector<OpOperand *> opOperands =
-                getAliasingOpOperand(opResult);
+                state.getAliasingOpOperand(opResult);
             if (!llvm::all_of(opOperands, [&](OpOperand *operand) {
                   return aliasInfo.areEquivalentBufferizedValues(operand->get(),
                                                                  opResult);

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 0e391a9a4f04..49687ccacd3d 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -490,7 +490,8 @@ namespace std_ext {
 
 struct CallOpInterface
     : public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              BufferizationState &state) const {
     // CallOpInterface alone doesn't bufferize to a memory read, one of the uses
     // of the matching bbArg may. It is the responsibility of the caller to
     // inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
@@ -498,7 +499,8 @@ struct CallOpInterface
     return true;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     // CallOpInterface is special, it needs to wait for the callee to be
     // bufferized and needs to inspect the BufferAliasInfo object. It can't
     // make a proper determination by itself and needs to be conservative.
@@ -618,15 +620,18 @@ struct CallOpInterface
 struct ReturnOpInterface
     : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
                                                     ReturnOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              BufferizationState &state) const {
     return true;
   }
 
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     return false;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     return OpResult();
   }
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index edded005a1ee..f3e4aa4d9c98 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -22,8 +22,9 @@ namespace scf_ext {
 struct ExecuteRegionOpInterface
     : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                     scf::ExecuteRegionOp> {
-  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
-                                                OpResult opResult) const {
+  SmallVector<OpOperand *>
+  getAliasingOpOperand(Operation *op, OpResult opResult,
+                       BufferizationState &state) const {
     // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
     // any SSA value that is in scope. To allow for use-def chain traversal
     // through ExecuteRegionOps in the analysis, the corresponding yield value
@@ -39,7 +40,8 @@ struct ExecuteRegionOpInterface
     return {&yieldOp->getOpOperand(resultNum)};
   }
 
-  bool mustBufferizeInPlace(Operation *op, OpResult opResult) const {
+  bool mustBufferizeInPlace(Operation *op, OpResult opResult,
+                            BufferizationState &state) const {
     // ExecuteRegionOp results always bufferize in-place. Since they have no
     // OpOperands, they are mostly ignored by the analysis once alias sets are
     // set up.
@@ -48,7 +50,8 @@ struct ExecuteRegionOpInterface
 
   // TODO: For better bufferization results, this could return `true` only if
   // there is a memory write in the region.
-  bool isMemoryWrite(Operation *op, OpResult opResult) const {
+  bool isMemoryWrite(Operation *op, OpResult opResult,
+                     BufferizationState &state) const {
     // Similar to scf.if, results of this op are always considered memory writes
     // in the analysis. This is a useful pattern for all ops that have tensor
     // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
@@ -71,15 +74,17 @@ struct ExecuteRegionOpInterface
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo) const {
+                                const BufferizationAliasInfo &aliasInfo,
+                                BufferizationState &state) const {
     return BufferRelation::Equivalent;
   }
 };
 
 struct IfOpInterface
     : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
-  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
-                                                OpResult opResult) const {
+  SmallVector<OpOperand *>
+  getAliasingOpOperand(Operation *op, OpResult opResult,
+                       BufferizationState &state) const {
     // IfOps do not have tensor OpOperands. The yielded value can be any SSA
     // value that is in scope. To allow for use-def chain traversal through
     // IfOps in the analysis, both corresponding yield values from the then/else
@@ -95,7 +100,8 @@ struct IfOpInterface
   // there is a memory write in one (or both) of the branches. Since this is not
   // allowed at the moment, we should never encounter scf.ifs that yield
   // unmodified tensors. Such scf.yield ops could just fold away.
-  bool isMemoryWrite(Operation *op, OpResult opResult) const {
+  bool isMemoryWrite(Operation *op, OpResult opResult,
+                     BufferizationState &state) const {
     // IfOp results are always considered memory writes in the analysis. This
     // design decision simplifies the analysis considerably. E.g., consider the
     // following test case:
@@ -121,7 +127,8 @@ struct IfOpInterface
     return true;
   }
 
-  bool mustBufferizeInPlace(Operation *op, OpResult opResult) const {
+  bool mustBufferizeInPlace(Operation *op, OpResult opResult,
+                            BufferizationState &state) const {
     // IfOp results always bufferize in-place. Since they have no OpOperands,
     // they are mostly ignored by the analysis once alias sets are set up.
     return true;
@@ -203,12 +210,13 @@ struct IfOpInterface
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo) const {
+                                const BufferizationAliasInfo &aliasInfo,
+                                BufferizationState &state) const {
     // IfOp results are equivalent to their corresponding yield values if both
     // yield values are equivalent to each other.
     auto bufferizableOp = cast<BufferizableOpInterface>(op);
     SmallVector<OpOperand *> yieldValues =
-        bufferizableOp.getAliasingOpOperand(opResult);
+        bufferizableOp.getAliasingOpOperand(opResult, state);
     assert(yieldValues.size() == 2 && "expected 2 yield values");
     bool equivalentYields = aliasInfo.areEquivalentBufferizedValues(
         yieldValues[0]->get(), yieldValues[1]->get());
@@ -219,21 +227,24 @@ struct IfOpInterface
 struct ForOpInterface
     : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                     scf::ForOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              BufferizationState &state) const {
     // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
     // its matching bbArg may.
     auto forOp = cast<scf::ForOp>(op);
-    return isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
+    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
   }
 
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     // Tensor iter_args of scf::ForOps are always considered as a write. This is
     // to simplify the analysis.
     // TODO: Consider doing sth. like isValueWritten.
     return true;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     auto forOp = cast<scf::ForOp>(op);
     if (!opOperand.get().getType().isa<RankedTensorType>())
       return OpResult();
@@ -241,7 +252,8 @@ struct ForOpInterface
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo) const {
+                                const BufferizationAliasInfo &aliasInfo,
+                                BufferizationState &state) const {
     // ForOp results are equivalent to their corresponding init_args if the
     // corresponding iter_args and yield values are equivalent.
     auto forOp = cast<scf::ForOp>(op);
@@ -410,15 +422,18 @@ LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext::
 struct YieldOpInterface
     : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                     scf::YieldOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              BufferizationState &state) const {
     return true;
   }
 
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     return false;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     return OpResult();
   }
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 7558d792facf..30ca9ed0a78b 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -40,20 +40,24 @@ getTensorBufferizationState(BufferizationState &state) {
 struct CastOpInterface
     : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                     tensor::CastOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              BufferizationState &state) const {
     return false;
   }
 
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     return false;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     return op->getResult(0);
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo) const {
+                                const BufferizationAliasInfo &aliasInfo,
+                                BufferizationState &state) const {
     return BufferRelation::Equivalent;
   }
 
@@ -86,15 +90,18 @@ struct CastOpInterface
 struct DimOpInterface
     : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                     tensor::DimOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              BufferizationState &state) const {
     return true;
   }
 
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     return false;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     return OpResult();
   }
 
@@ -112,22 +119,26 @@ struct DimOpInterface
 struct ExtractSliceOpInterface
     : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                     tensor::ExtractSliceOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              BufferizationState &state) const {
     return false;
   }
 
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     return false;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     return &opOperand == &op->getOpOperand(0) /*source*/
                ? op->getResult(0)
                : OpResult();
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo) const {
+                                const BufferizationAliasInfo &aliasInfo,
+                                BufferizationState &state) const {
     return BufferRelation::None;
   }
 
@@ -160,7 +171,7 @@ struct ExtractSliceOpInterface
     /// If not inplaceable, copy.
     if (!inplace) {
       // Do not copy if the copied data is never read.
-      if (isValueRead(extractSliceOp.result()))
+      if (state.isValueRead(extractSliceOp.result()))
         state.createMemCpy(b, extractSliceOp.getLoc(), subView, alloc);
       subView = alloc;
     }
@@ -173,15 +184,18 @@ struct ExtractSliceOpInterface
 struct ExtractOpInterface
     : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                     tensor::ExtractOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              BufferizationState &state) const {
     return true;
   }
 
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     return false;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     return OpResult();
   }
 
@@ -198,22 +212,26 @@ struct ExtractOpInterface
 struct InsertOpInterface
     : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                     tensor::InsertOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              BufferizationState &state) const {
     return true;
   }
 
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     return true;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
            "expected dest OpOperand");
     return op->getOpResult(0);
   }
 
-  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
-                                                OpResult opResult) const {
+  SmallVector<OpOperand *>
+  getAliasingOpOperand(Operation *op, OpResult opResult,
+                       BufferizationState &state) const {
     return {&op->getOpOperand(1) /*dest*/};
   }
 
@@ -229,7 +247,8 @@ struct InsertOpInterface
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo) const {
+                                const BufferizationAliasInfo &aliasInfo,
+                                BufferizationState &state) const {
     return BufferRelation::Equivalent;
   }
 };
@@ -272,8 +291,8 @@ static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
 /// Return true if `value` is originating from an ExtractSliceOp that matches
 /// the given InsertSliceOp.
 static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
-                                      const BufferizationOptions &options,
-                                      Value value, InsertSliceOp insertOp) {
+                                      BufferizationState &state, Value value,
+                                      InsertSliceOp insertOp) {
   auto condition = [&](Value val) {
     if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
       if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
@@ -281,29 +300,33 @@ static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
     return false;
   };
 
-  return llvm::all_of(findValueInReverseUseDefChain(value, options, condition),
+  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                       condition);
 }
 
 struct InsertSliceOpInterface
     : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                     tensor::InsertSliceOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              BufferizationState &state) const {
     return true;
   }
 
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     return &opOperand == &op->getOpOperand(1) /*dest*/;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     return &opOperand == &op->getOpOperand(1) /*dest*/
                ? op->getResult(0)
                : OpResult();
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo) const {
+                                const BufferizationAliasInfo &aliasInfo,
+                                BufferizationState &state) const {
     return BufferRelation::Equivalent;
   }
 
@@ -325,8 +348,8 @@ struct InsertSliceOpInterface
 
       // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
       if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(aliasInfo, state.getOptions(),
-                                    uConflictingWrite->get(), insertSliceOp))
+          hasMatchingExtractSliceOp(aliasInfo, state, uConflictingWrite->get(),
+                                    insertSliceOp))
         // Case 1: The main insight is that InsertSliceOp reads only part of
         // the destination tensor. The overwritten area is not read. If
         // uConflictingWrite writes into exactly the memory location that is
@@ -343,7 +366,7 @@ struct InsertSliceOpInterface
 
       if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
           uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(aliasInfo, state.getOptions(), uRead->get(),
+          hasMatchingExtractSliceOp(aliasInfo, state, uRead->get(),
                                     insertSliceOp))
         // Case 2: The read of the source tensor and the write to the dest
         // tensor via an InsertSliceOp is not a conflict if the read is
@@ -377,8 +400,8 @@ struct InsertSliceOpInterface
       if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
           aliasInfo.areEquivalentBufferizedValues(uRead->get(),
                                                   insertSliceOp.source()) &&
-          hasMatchingExtractSliceOp(aliasInfo, state.getOptions(),
-                                    insertSliceOp.source(), insertSliceOp))
+          hasMatchingExtractSliceOp(aliasInfo, state, insertSliceOp.source(),
+                                    insertSliceOp))
         return true;
 
     return false;

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index 3ccfb5065ed2..50ceb5aa77c9 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -20,19 +20,22 @@ namespace vector_ext {
 struct TransferReadOpInterface
     : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
                                                     vector::TransferReadOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              BufferizationState &state) const {
     assert(opOperand.get().getType().isa<RankedTensorType>() &&
            "only tensor types expected");
     return true;
   }
 
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     assert(opOperand.get().getType().isa<RankedTensorType>() &&
            "only tensor types expected");
     return false;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     return OpResult();
   }
 
@@ -56,26 +59,30 @@ struct TransferReadOpInterface
 struct TransferWriteOpInterface
     : public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface,
                                                     vector::TransferWriteOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              BufferizationState &state) const {
     assert(opOperand.get().getType().isa<TensorType>() &&
            "only tensor types expected");
     return true;
   }
 
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     assert(opOperand.get().getType().isa<TensorType>() &&
            "only tensor types expected");
     return true;
   }
 
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               BufferizationState &state) const {
     assert(opOperand.get().getType().isa<TensorType>() &&
            "only tensor types expected");
     return op->getOpResult(0);
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo) const {
+                                const BufferizationAliasInfo &aliasInfo,
+                                BufferizationState &state) const {
     return BufferRelation::Equivalent;
   }
 


        


More information about the Mlir-commits mailing list