[Mlir-commits] [mlir] aeb1c8d - [mlir][linalg][bufferize] Group helpers in BufferizationState

Matthias Springer llvmlistbot at llvm.org
Thu Nov 11 01:28:40 PST 2021


Author: Matthias Springer
Date: 2021-11-11T18:24:13+09:00
New Revision: aeb1c8d0cae8e13cc97481e54ef11867808fe5b8

URL: https://github.com/llvm/llvm-project/commit/aeb1c8d0cae8e13cc97481e54ef11867808fe5b8
DIFF: https://github.com/llvm/llvm-project/commit/aeb1c8d0cae8e13cc97481e54ef11867808fe5b8.diff

LOG: [mlir][linalg][bufferize] Group helpers in BufferizationState

This simplifies the signature of `bufferize`.

Differential Revision: https://reviews.llvm.org/D113388

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 42908bc6c5c2..3c3f601a9385 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -197,6 +197,8 @@ findValueInReverseUseDefChain(Value value,
 /// is returned regardless of whether it is a memory write or not.
 Value findLastPrecedingWrite(Value value);
 
+struct BufferizationState;
+
 /// Callback functions that are used to allocate/deallocate/copy memory buffers.
 /// Comprehensive Bufferize provides default implementations of these functions.
 // TODO: Could be replaced with a "bufferization strategy" object with virtual
@@ -207,8 +209,7 @@ struct AllocationCallbacks {
   using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
   using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
   using CreateAllocDeallocFn =
-      std::function<Value(OpBuilder &, Location, Value,
-                          BufferizationAliasInfo &, AllocationCallbacks &)>;
+      std::function<Value(OpBuilder &, Location, Value, BufferizationState &)>;
 
   AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
                       MemCpyFn copyFn, CreateAllocDeallocFn allocDeallocFn)
@@ -230,13 +231,40 @@ struct AllocationCallbacks {
   CreateAllocDeallocFn createAllocDeallocFn;
 };
 
+/// BufferizationState bundles the analysis results and helpers used during
+/// bufferization. It stores references only; referents must outlive it.
+struct BufferizationState {
+  BufferizationState(BufferizationAliasInfo &aliasInfo,
+                     AllocationCallbacks &allocationFns,
+                     BlockAndValueMapping &tensorToBufferMap)
+      : aliasInfo(aliasInfo), allocationFns(allocationFns),
+        tensorToBufferMap(tensorToBufferMap) {}
+
+  /// Map tensor values to memref buffers. `tensors` must be non-empty.
+  void mapBuffer(ValueRange tensors, ValueRange buffers);
+
+  /// Map a (non-null, tensor-typed) tensor value to a memref buffer.
+  void mapBuffer(Value tensor, Value buffer);
+
+  /// Look up the memref buffer associated with the given tensor value.
+  /// Hits llvm_unreachable (after dumping `tensor`) if no buffer is mapped.
+  Value lookupBuffer(Value tensor) const;
+
+  /// `aliasInfo` keeps track of aliasing and equivalent values.
+  BufferizationAliasInfo &aliasInfo;
+
+  /// `allocationFns` contains helper functions for creating alloc ops, dealloc
+  /// ops and memcpy ops.
+  AllocationCallbacks &allocationFns;
+
+  /// Tensor-to-buffer mapping, written by mapBuffer and read by lookupBuffer.
+  BlockAndValueMapping &tensorToBufferMap;
+};
+
 /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
 /// a new buffer and copy over data from the existing buffer if out-of-place
 /// bufferization is necessary.
-Value getResultBuffer(OpBuilder &b, OpResult result,
-                      const BlockAndValueMapping &bvm,
-                      BufferizationAliasInfo &aliasInfo,
-                      AllocationCallbacks allocationFns);
+Value getResultBuffer(OpBuilder &b, OpResult result, BufferizationState &state);
 
 } // namespace comprehensive_bufferize
 } // namespace linalg
@@ -280,9 +308,7 @@ struct AllocationHoistingBarrierOnly
   bool isWritable(Operation *op, Value value) const { return false; }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
     if (any_of(op->getOperandTypes(), isaTensor) ||
         any_of(op->getResultTypes(), isaTensor))

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
index 8c8846753304..66e26b3e6e61 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
@@ -160,8 +160,6 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           llvm_unreachable("bufferRelation not implemented");
         }]
       >,
-      // TODO: Simplify method signature: Pass an OpBuilder and a
-      // BufferizationState object.
       InterfaceMethod<
         /*desc=*/[{
           Bufferize this op, i.e., rewrite it into a memref-based equivalent.
@@ -171,9 +169,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*retType=*/"LogicalResult",
         /*methodName=*/"bufferize",
         /*args=*/(ins "OpBuilder &":$b,
-                      "BlockAndValueMapping &":$bvm,
-                      "BufferizationAliasInfo &":$aliasInfo,
-                      "AllocationCallbacks &":$allocationFn),
+                      "BufferizationState &":$state),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           llvm_unreachable("bufferize not implemented");

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
index 35b6e6f2abe2..72a5c700c10a 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
@@ -27,6 +27,8 @@ namespace comprehensive_bufferize {
 // TODO: from some HW description.
 static constexpr int64_t kBufferAlignments = 128;
 
+struct BufferizationState;
+
 /// Analyze the `ops` to determine which OpResults are inplaceable.
 LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
                               BufferizationAliasInfo &aliasInfo,
@@ -55,9 +57,7 @@ std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();
 /// `bufferizedFunctionTypes` (resp. `globalCreator`) are expected to be
 /// non-null if `op` is a CallOpInterface (resp. GlobalCreator).
 LogicalResult
-bufferizeOp(Operation *op, BlockAndValueMapping &bvm,
-            BufferizationAliasInfo &aliasInfo,
-            AllocationCallbacks allocationFns,
+bufferizeOp(Operation *op, BufferizationState &state,
             DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes = nullptr);
 
 /// Register external models implemented for the `BufferizableOpInterface`.

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 3c5c95e87a04..bada67c7c1b7 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -7,8 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/IR/AsmState.h"
 #include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
 #include "llvm/Support/Debug.h"
 
 namespace mlir {
@@ -319,30 +322,28 @@ Value mlir::linalg::comprehensive_bufferize::findLastPrecedingWrite(
 /// a new buffer and copy over data from the existing buffer if out-of-place
 /// bufferization is necessary.
 Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
-    OpBuilder &b, OpResult result, const BlockAndValueMapping &bvm,
-    BufferizationAliasInfo &aliasInfo, AllocationCallbacks allocationFns) {
+    OpBuilder &b, OpResult result, BufferizationState &state) {
   OpBuilder::InsertionGuard guard(b);
   Operation *op = result.getOwner();
   SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
   assert(!aliasingOperands.empty() && "could not get aliasing OpOperand");
   OpOperand *opOperand = aliasingOperands.front();
   Value operand = opOperand->get();
-  Value operandBuffer = bvm.lookupOrNull(operand);
-  assert(operandBuffer && "operand buffer not found");
+  Value operandBuffer = state.lookupBuffer(operand);
   // Make sure that all OpOperands are the same buffer. If this is not the case,
   // we would have to materialize a memref value.
   // TODO: Should be looking for checking for "equivalent buffers" instead of
   // operator== here, but equivalent buffers for scf.if yield values are not
   // set up yet.
   if (!llvm::all_of(aliasingOperands, [&](OpOperand *o) {
-        return bvm.lookup(o->get()) == operandBuffer;
+        return state.lookupBuffer(o->get()) == operandBuffer;
       })) {
     op->emitError("result buffer is ambiguous");
     return Value();
   }
 
   // If bufferizing out-of-place, allocate a new buffer.
-  if (!aliasInfo.isInPlace(result)) {
+  if (!state.aliasInfo.isInPlace(result)) {
     // Ops with multiple aliasing operands can currently not bufferize
     // out-of-place.
     assert(
@@ -350,8 +351,8 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
         "ops with multiple aliasing OpOperands cannot bufferize out-of-place");
     Location loc = op->getLoc();
     // Allocate the result buffer.
-    Value resultBuffer = allocationFns.createAllocDeallocFn(
-        b, loc, operand, aliasInfo, allocationFns);
+    Value resultBuffer =
+        state.allocationFns.createAllocDeallocFn(b, loc, operand, state);
     bool skipCopy = false;
     // Do not copy if the last preceding write of `operand` is an op that does
     // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@@ -373,7 +374,7 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
     if (!skipCopy) {
       // Set insertion point now that potential alloc/dealloc are introduced.
       b.setInsertionPoint(op);
-      allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer);
+      state.allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer);
     }
     return resultBuffer;
   }
@@ -381,3 +382,39 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
   // Bufferizing in-place. No need to allocate a new buffer.
   return operandBuffer;
 }
+
+//===----------------------------------------------------------------------===//
+// Bufferization-specific BlockAndValueMapping support with debugging.
+//===----------------------------------------------------------------------===//
+
+/// Map `tensors` to `buffers` in tensorToBufferMap (no size/type checks here).
+void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer(
+    ValueRange tensors, ValueRange buffers) {
+  assert(!tensors.empty() && "unexpected empty tensors");
+  return tensorToBufferMap.map(tensors, buffers);
+}
+
+/// Map `tensor` to `buffer`; asserts `tensor` is non-null and tensor-typed.
+void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer(
+    Value tensor, Value buffer) {
+  assert(tensor && "unexpected empty tensor");
+  assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
+  return tensorToBufferMap.map(tensor, buffer);
+}
+
+/// Return the buffer mapped to `tensor`; an unmapped tensor is a fatal error.
+Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
+    Value tensor) const {
+  // TODO: if key comes from bbArg, forward.
+  assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
+  Value v = tensorToBufferMap.lookupOrNull(tensor);
+
+  if (!v) {
+    // Dump tensor for easier debugging, then hit llvm_unreachable.
+    tensor.dump();
+    llvm_unreachable("tensor is not mapped");
+    return Value(); // Not reached: llvm_unreachable does not return.
+  }
+
+  return v;
+}

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 233c896648b9..60b0e86d8628 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -172,47 +172,6 @@ static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
   return returnOp;
 }
 
-//===----------------------------------------------------------------------===//
-// Bufferization-specific BlockAndValueMapping support with debugging.
-//===----------------------------------------------------------------------===//
-
-/// Wrapper for better debugging.
-static void map(BlockAndValueMapping &bvm, ValueRange keys, ValueRange values) {
-  assert(!keys.empty() && "Unexpected empty keys");
-  LDBG("\n\tMap: " << printValueInfo(keys.front())
-                   << "\n\tto: " << printValueInfo(values.front()) << '\n');
-  return bvm.map(keys, values);
-}
-
-/// Wrapper for better debugging.
-static void map(BlockAndValueMapping &bvm, Value key, Value value) {
-  LDBG("\n\tMap: " << printValueInfo(key) << "\n\tto: " << printValueInfo(value)
-                   << '\n');
-  return bvm.map(key, value);
-}
-
-/// Wrapper for better debugging.
-static Value lookup(const BlockAndValueMapping &bvm, Value key) {
-  // TODO: if key comes from bbArg, forward.
-  assert(key.getType().isa<TensorType>());
-  Value v = bvm.lookupOrNull(key);
-  if (v)
-    return v;
-
-  Operation *parentOp;
-  if (auto bbArg = key.dyn_cast<BlockArgument>()) {
-    if (isa<FuncOp>(key.getParentBlock()->getParentOp()))
-      parentOp = key.getParentBlock()->getParentOp();
-    else
-      parentOp = key.getParentBlock()->getParentOp()->getParentOfType<FuncOp>();
-  } else {
-    parentOp = key.getDefiningOp()->getParentOfType<FuncOp>();
-  }
-  LDBG("In func:\n" << *parentOp << "\nNO VALUE FOR KEY: " << key << '\n');
-  (void)parentOp;
-  return Value();
-}
-
 //===----------------------------------------------------------------------===//
 // Bufferization-specific attribute manipulation.
 // These are for testing and debugging only. Bufferization information is
@@ -878,8 +837,7 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
 /// `shapedValue.getDefiningOp` (or at the top of the block in case of a
 /// bbArg) and the DeallocOp is at the end of the block.
 static Value createNewAllocDeallocPairForShapedValue(
-    OpBuilder &b, Location loc, Value shapedValue,
-    BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) {
+    OpBuilder &b, Location loc, Value shapedValue, BufferizationState &state) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -891,19 +849,19 @@ static Value createNewAllocDeallocPairForShapedValue(
   MemRefType allocMemRefType =
       getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
   Optional<Value> allocated =
-      allocationFns.allocationFn(b, loc, allocMemRefType, dynShape);
+      state.allocationFns.allocationFn(b, loc, allocMemRefType, dynShape);
   // TODO: For now just assert the value is returned. Eventually need to
   // error-propagate.
   assert(allocated && "allocation failed");
   Value casted = allocated.getValue();
   if (memRefType && memRefType != allocMemRefType) {
     casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
-    aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
+    state.aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
   }
 
   // 2. Create memory deallocation.
   b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
-  allocationFns.deallocationFn(b, loc, allocated.getValue());
+  state.allocationFns.deallocationFn(b, loc, allocated.getValue());
   return casted;
 }
 
@@ -915,8 +873,7 @@ static Value createNewAllocDeallocPairForShapedValue(
 /// inplaceable. For now, it is the responsibility of the `callOp` bufferization
 /// to allow FuncOp that are inplaceable to write inPlace.
 static LogicalResult
-bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
-          BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns,
+bufferize(OpBuilder &b, CallOpInterface callOp, BufferizationState &state,
           DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
   FuncOp funcOp = getCalledFunction(callOp);
   assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
@@ -962,14 +919,13 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
       // If return operand is equivalent to some bbArg, no need to return it.
       Value returnVal = returnOperand.get();
       if (BlockArgument bbArg =
-              getEquivalentEnclosingFuncBBArg(returnVal, aliasInfo)) {
+              getEquivalentEnclosingFuncBBArg(returnVal, state.aliasInfo)) {
         Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
         int64_t idx = bbArg.getArgNumber();
-        Value buffer = lookup(bvm, callOp->getOperand(idx));
-        assert(buffer && "expected bufferized value");
+        Value buffer = state.lookupBuffer(callOp->getOperand(idx));
         // Add CallOp operand/result equivalence: this is interprocedural info.
-        aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
-        map(bvm, oldRes, buffer);
+        state.aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
+        state.mapBuffer(oldRes, buffer);
         // Add a TensorLoadOp to kill all uses of the CallOp return.
         // Replace all uses of the CallOp results so we can erase the CallOp.
         // This TensorLoadOp must fold/DCE away or bufferization should be
@@ -978,13 +934,13 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
             b.create<memref::TensorLoadOp>(callOp.getLoc(), buffer);
         oldRes.replaceAllUsesWith(tensorLoad);
         // Add new op equivalence info.
-        aliasInfo.insertNewBufferEquivalence(tensorLoad, buffer);
-        map(bvm, tensorLoad, buffer);
+        state.aliasInfo.insertNewBufferEquivalence(tensorLoad, buffer);
+        state.mapBuffer(tensorLoad, buffer);
         continue;
       }
 
       // TODO: Need to hoist above function boundary.
-      if (Operation *allocOp = getEquivalentAlloc(returnVal, aliasInfo)) {
+      if (Operation *allocOp = getEquivalentAlloc(returnVal, state.aliasInfo)) {
         hoistedArguments.push_back(allocOp->getResult(0));
         continue;
       }
@@ -1023,8 +979,7 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
 
     // Tensor operands are guaranteed to have been buferized.
     int64_t idx = opOperand.getOperandNumber();
-    Value buffer = lookup(bvm, tensorOperand);
-    assert(buffer && "expected bufferized value");
+    Value buffer = state.lookupBuffer(tensorOperand);
 
     // Caller / callee type mistmatch is handled with a CastOp.
     auto memRefType = bufferizedFuncType.getInput(idx);
@@ -1037,8 +992,8 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
       Value castBuffer =
           b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
       // Add new op equivalence info.
-      aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
-      map(bvm, tensorOperand, castBuffer);
+      state.aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
+      state.mapBuffer(tensorOperand, castBuffer);
       buffer = castBuffer;
     }
     newOperands.push_back(buffer);
@@ -1054,9 +1009,7 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
 
 /// FuncOp always creates TensorToMemRef ops.
 static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
-                               BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo,
-                               AllocationCallbacks &allocationFn) {
+                               BufferizationState &state) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPointToStart(&funcOp.body().front());
@@ -1072,8 +1025,8 @@ static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
                           : getContiguousOrUnrankedMemRefType(tensorType);
     Value bufferCast =
         b.create<memref::BufferCastOp>(funcOp.getLoc(), memRefType, bbArg);
-    aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
-    map(bvm, bbArg, bufferCast);
+    state.aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
+    state.mapBuffer(bbArg, bufferCast);
   }
   return success();
 }
@@ -1230,8 +1183,7 @@ void mlir::linalg::comprehensive_bufferize::defaultMemCpyFn(OpBuilder &b,
 }
 
 LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
-    Operation *op, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
-    AllocationCallbacks allocationFns,
+    Operation *op, BufferizationState &state,
     DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes) {
   OpBuilder b(op->getContext());
 
@@ -1241,8 +1193,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
     if (!bufferizedFunctionTypes)
       llvm_unreachable(
           "null bufferizedFunctionTypes when bufferizing CallOpInterface");
-    return bufferize(b, callOp, bvm, aliasInfo, allocationFns,
-                     *bufferizedFunctionTypes);
+    return bufferize(b, callOp, state, *bufferizedFunctionTypes);
   }
 
   // Skip BufferCast and TensorLoad ops.
@@ -1251,7 +1202,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
 
   // Bufferize using `BufferizableOpInterface`.
   if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
-    return bufferizableOp.bufferize(b, bvm, aliasInfo, allocationFns);
+    return bufferizableOp.bufferize(b, state);
 
   // Other op with tensors. No bufferization method specified.
   auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
@@ -1262,23 +1213,21 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
 }
 
 static LogicalResult bufferizeFuncOpInternals(
-    FuncOp funcOp, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
-    AllocationCallbacks &allocationFns,
+    FuncOp funcOp, BufferizationState &state,
     DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
   LLVM_DEBUG(llvm::dbgs() << "\n\n");
   LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
   OpBuilder b(funcOp->getContext());
 
   // Start by bufferizing `funcOp` arguments.
-  if (failed(bufferize(b, funcOp, bvm, aliasInfo, allocationFns)))
+  if (failed(bufferize(b, funcOp, state)))
     return failure();
 
   // Cannot erase ops during the traversal. Do that afterwards.
   SmallVector<Operation *> toErase;
 
   auto walkFunc = [&](Operation *op) -> WalkResult {
-    if (failed(bufferizeOp(op, bvm, aliasInfo, allocationFns,
-                           &bufferizedFunctionTypes)))
+    if (failed(bufferizeOp(op, state, &bufferizedFunctionTypes)))
       return failure();
 
     // Register post-walk erasure, if necessary.
@@ -1852,9 +1801,10 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     // Bufferization phase.
     if (!options.testAnalysisOnly) {
       BlockAndValueMapping tensorToBufferMap;
-      if (failed(bufferizeFuncOpInternals(funcOp, tensorToBufferMap, aliasInfo,
-                                          *options.allocationFns,
-                                          bufferizedFunctionTypes)))
+      BufferizationState state(aliasInfo, *options.allocationFns,
+                               tensorToBufferMap);
+      if (failed(
+              bufferizeFuncOpInternals(funcOp, state, bufferizedFunctionTypes)))
         return failure();
     }
   }
@@ -1926,9 +1876,7 @@ struct ConstantOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto constantOp = cast<arith::ConstantOp>(op);
     if (!isaTensor(constantOp.getResult().getType()))
       return success();
@@ -1948,8 +1896,8 @@ struct ConstantOpInterface
     auto globalMemref = globalCreator.getGlobalFor(constantOp);
     Value memref = b.create<memref::GetGlobalOp>(
         constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
-    aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
-    map(bvm, constantOp, memref);
+    state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
+    state.mapBuffer(constantOp, memref);
 
     return success();
   }
@@ -1969,10 +1917,10 @@ namespace linalg_ext {
 /// Helper function for LinalgOp bufferization.
 /// When allocating a new buffer, analyze whether `op` wants to read form that
 /// buffer. Only in that case, a copy of the result buffer may be needed.
-static LogicalResult allocateBuffersForResults(
-    OpBuilder &b, Location loc, LinalgOp op,
-    SmallVectorImpl<Value> &resultBuffers, BlockAndValueMapping &bvm,
-    BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) {
+static LogicalResult
+allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
+                          SmallVectorImpl<Value> &resultBuffers,
+                          BufferizationState &state) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
@@ -1983,24 +1931,21 @@ static LogicalResult allocateBuffersForResults(
     OpResult opResult = cast<BufferizableOpInterface>(op.getOperation())
                             .getAliasingOpResult(*opOperand);
     assert(opResult && "could not find correspond OpResult");
-    Value resultBuffer =
-        getResultBuffer(b, opResult, bvm, aliasInfo, allocationFns);
+    Value resultBuffer = getResultBuffer(b, opResult, state);
     if (!resultBuffer)
       return failure();
     resultBuffers.push_back(resultBuffer);
   }
 
   if (op->getNumResults())
-    map(bvm, op->getResults(), resultBuffers);
+    state.mapBuffer(op->getResults(), resultBuffers);
 
   return success();
 }
 
 /// Generic conversion for any LinalgOp on tensors.
 static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
-                                       BlockAndValueMapping &bvm,
-                                       BufferizationAliasInfo &aliasInfo,
-                                       AllocationCallbacks &allocationFns) {
+                                       BufferizationState &state) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -2017,13 +1962,11 @@ static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
       newInputBuffers.push_back(opOperand->get());
       continue;
     }
-    newInputBuffers.push_back(lookup(bvm, opOperand->get()));
-    assert(newInputBuffers.back() && "missing buffer");
+    newInputBuffers.push_back(state.lookupBuffer(opOperand->get()));
   }
   SmallVector<Value> newOutputBuffers;
   // Try to allocate new buffers depending on op's inplace semantics.
-  if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm,
-                                       aliasInfo, allocationFns)))
+  if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, state)))
     return failure();
 
   // Clone the newly bufferized op.
@@ -2036,7 +1979,7 @@ static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
 
   // Replace the results of the old op with the new output buffers.
   if (op->getNumResults())
-    map(bvm, op->getResults(), newOutputBuffers);
+    state.mapBuffer(op->getResults(), newOutputBuffers);
 
   // The original op will be DCE'd away later.
 
@@ -2087,11 +2030,8 @@ struct LinalgOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
-    return bufferizeLinalgOp(b, cast<LinalgOp>(op), bvm, aliasInfo,
-                             allocationFn);
+                          BufferizationState &state) const {
+    return bufferizeLinalgOp(b, cast<LinalgOp>(op), state);
   }
 };
 
@@ -2109,9 +2049,7 @@ struct InitTensorOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto initTensorOp = cast<linalg::InitTensorOp>(op);
 
     // The InitTensorOp may have been eliminated.
@@ -2123,9 +2061,8 @@ struct InitTensorOpInterface
     b.setInsertionPoint(initTensorOp);
 
     Value alloc = createNewAllocDeallocPairForShapedValue(
-        b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo,
-        allocationFn);
-    map(bvm, initTensorOp.result(), alloc);
+        b, initTensorOp->getLoc(), initTensorOp.result(), state);
+    state.mapBuffer(initTensorOp.result(), alloc);
     return success();
   }
 };
@@ -2178,9 +2115,7 @@ struct TiledLoopOpInterface
   bool isAllocationHoistingBarrier(Operation *op) const { return true; }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
 
     // Take a guard before anything else.
@@ -2222,15 +2157,14 @@ struct TiledLoopOpInterface
 
       const OpResult &opResult = tiledLoopOp->getResult(resultIndex);
       OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex);
-      Value resultBuffer =
-          getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
+      Value resultBuffer = getResultBuffer(b, opResult, state);
       if (!resultBuffer)
         return failure();
 
       // Insert mapping and aliasing info.
-      aliasInfo.createAliasInfoEntry(resultBuffer);
-      aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer);
-      map(bvm, opResult, resultBuffer);
+      state.aliasInfo.createAliasInfoEntry(resultBuffer);
+      state.aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer);
+      state.mapBuffer(opResult, resultBuffer);
 
       // Insert new operand and bbArg.
       tiledLoopOp->insertOperands(nextOutputOperandIndex, resultBuffer);
@@ -2238,9 +2172,10 @@ struct TiledLoopOpInterface
           body->insertArgument(nextOutputBBArgIndex, resultBuffer.getType());
       BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex);
       // Insert mapping and aliasing info.
-      aliasInfo.createAliasInfoEntry(newBufferBBArg);
-      aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, newBufferBBArg);
-      map(bvm, oldTensorBBArg, newBufferBBArg);
+      state.aliasInfo.createAliasInfoEntry(newBufferBBArg);
+      state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg,
+                                                 newBufferBBArg);
+      state.mapBuffer(oldTensorBBArg, newBufferBBArg);
 
       // Set operand of `linalg.yield` to the bbArg so it just canonicalizes
       // away later.
@@ -2268,8 +2203,7 @@ struct TiledLoopOpInterface
         continue;
       }
 
-      Value inputBuffer = lookup(bvm, oldInputTensor);
-      assert(inputBuffer && " missing buffer for operand");
+      Value inputBuffer = state.lookupBuffer(oldInputTensor);
 
       // Insert new operand and bbArg.
       tiledLoopOp->insertOperands(nextInputOperandIndex, inputBuffer);
@@ -2278,9 +2212,10 @@ struct TiledLoopOpInterface
       BlockArgument oldTensorBBArg = body->getArgument(oldInputBBArgIndex);
 
       // Insert mapping and aliasing info.
-      aliasInfo.createAliasInfoEntry(newBufferBBArg);
-      aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, newBufferBBArg);
-      map(bvm, oldTensorBBArg, newBufferBBArg);
+      state.aliasInfo.createAliasInfoEntry(newBufferBBArg);
+      state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg,
+                                                 newBufferBBArg);
+      state.mapBuffer(oldTensorBBArg, newBufferBBArg);
 
       // Increment indices.
       ++numNewInputBuffers;
@@ -2318,9 +2253,7 @@ struct YieldOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto yieldOp = cast<linalg::YieldOp>(op);
 
     // Take a guard before anything else.
@@ -2394,9 +2327,7 @@ struct IfOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     // scf::IfOp is bufferized after scf::YieldOp in the else branch.
     return success();
   }
@@ -2405,9 +2336,7 @@ struct IfOpInterface
 /// Bufferize the scf::IfOp. This function is called after the YieldOp was
 /// bufferized.
 static LogicalResult bufferizeIfOp(scf::IfOp ifOp, OpBuilder &b,
-                                   BlockAndValueMapping &bvm,
-                                   BufferizationAliasInfo &aliasInfo,
-                                   AllocationCallbacks &allocationFn) {
+                                   BufferizationState &state) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(ifOp);
@@ -2420,13 +2349,12 @@ static LogicalResult bufferizeIfOp(scf::IfOp ifOp, OpBuilder &b,
     assert(opResult.getType().isa<RankedTensorType>() &&
            "unsupported unranked tensor");
 
-    Value resultBuffer =
-        getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
+    Value resultBuffer = getResultBuffer(b, opResult, state);
     if (!resultBuffer)
       return failure();
 
-    aliasInfo.createAliasInfoEntry(resultBuffer);
-    map(bvm, opResult, resultBuffer);
+    state.aliasInfo.createAliasInfoEntry(resultBuffer);
+    state.mapBuffer(opResult, resultBuffer);
   }
 
   return success();
@@ -2477,9 +2405,7 @@ struct ForOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     // Note: This method is just setting up the mappings for the block arguments
     // and the result buffer. The op is bufferized after the scf::YieldOp.
 
@@ -2497,17 +2423,16 @@ struct ForOpInterface
              "unsupported unranked tensor");
 
       // TODO: More general: Matching bbArg does not bufferize to a read.
-      Value resultBuffer =
-          getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
+      Value resultBuffer = getResultBuffer(b, opResult, state);
       if (!resultBuffer)
         return failure();
 
       OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
       BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
-      aliasInfo.createAliasInfoEntry(resultBuffer);
-      aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
-      map(bvm, bbArg, resultBuffer);
-      map(bvm, opResult, resultBuffer);
+      state.aliasInfo.createAliasInfoEntry(resultBuffer);
+      state.aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
+      state.mapBuffer(bbArg, resultBuffer);
+      state.mapBuffer(opResult, resultBuffer);
     }
 
     return success();
@@ -2517,9 +2442,7 @@ struct ForOpInterface
 /// Bufferize the scf::ForOp. This function is called after the YieldOp was
 /// bufferized.
 static LogicalResult bufferizeForOp(scf::ForOp forOp, OpBuilder &b,
-                                    BlockAndValueMapping &bvm,
-                                    BufferizationAliasInfo &aliasInfo,
-                                    AllocationCallbacks &allocationFn) {
+                                    BufferizationState &state) {
   auto yieldOp = cast<scf::YieldOp>(&forOp.region().front().back());
   for (OpOperand &operand : yieldOp->getOpOperands()) {
     auto tensorType = operand.get().getType().dyn_cast<TensorType>();
@@ -2529,9 +2452,10 @@ static LogicalResult bufferizeForOp(scf::ForOp forOp, OpBuilder &b,
     OpOperand &forOperand = forOp.getOpOperandForResult(
         forOp->getResult(operand.getOperandNumber()));
     auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-    Value yieldedBuffer = lookup(bvm, operand.get());
-    Value bbArgBuffer = lookup(bvm, bbArg);
-    if (!aliasInfo.areEquivalentBufferizedValues(yieldedBuffer, bbArgBuffer)) {
+    Value yieldedBuffer = state.lookupBuffer(operand.get());
+    Value bbArgBuffer = state.lookupBuffer(bbArg);
+    if (!state.aliasInfo.areEquivalentBufferizedValues(yieldedBuffer,
+                                                       bbArgBuffer)) {
       // TODO: this could get resolved with copies but it can also turn into
       // swaps so we need to be careful about order of copies.
       return yieldOp->emitError()
@@ -2567,9 +2491,7 @@ struct YieldOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto yieldOp = cast<scf::YieldOp>(op);
 
     if (auto execOp = dyn_cast<scf::ExecuteRegionOp>(yieldOp->getParentOp())) {
@@ -2584,12 +2506,12 @@ struct YieldOpInterface
     if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
       if (ifOp.elseYield() != yieldOp)
         return success();
-      return bufferizeIfOp(ifOp, b, bvm, aliasInfo, allocationFn);
+      return bufferizeIfOp(ifOp, b, state);
     }
 
     // Bufferize scf::ForOp after bufferizing the scf::YieldOp.
     if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp()))
-      return bufferizeForOp(forOp, b, bvm, aliasInfo, allocationFn);
+      return bufferizeForOp(forOp, b, state);
 
     return yieldOp->emitError("expected scf::ForOp parent for scf::YieldOp");
   }
@@ -2635,9 +2557,7 @@ struct CallOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     llvm_unreachable("CallOps are handled separately");
     return failure();
   }
@@ -2659,9 +2579,7 @@ struct ReturnOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto returnOp = cast<ReturnOp>(op);
 
     // Take a guard before anything else.
@@ -2675,12 +2593,11 @@ struct ReturnOpInterface
       auto tensorType = operand.get().getType().dyn_cast<TensorType>();
       if (!tensorType)
         continue;
-      Value v = lookup(bvm, operand.get());
-      assert(v && "missing buffer for result");
+      Value v = state.lookupBuffer(operand.get());
       Value returnTensor = b.create<memref::TensorLoadOp>(returnOp.getLoc(), v);
       operand.set(returnTensor);
-      aliasInfo.insertNewBufferEquivalence(returnTensor, v);
-      map(bvm, returnTensor, v);
+      state.aliasInfo.insertNewBufferEquivalence(returnTensor, v);
+      state.mapBuffer(returnTensor, v);
     }
     return success();
   }
@@ -2715,17 +2632,14 @@ struct CastOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto castOp = cast<tensor::CastOp>(op);
 
     // Take a guard before anything else.
     OpBuilder::InsertionGuard g(b);
     b.setInsertionPoint(castOp);
 
-    Value resultBuffer =
-        getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo, allocationFn);
+    Value resultBuffer = getResultBuffer(b, castOp->getResult(0), state);
     if (!resultBuffer)
       return failure();
     Type sourceType = resultBuffer.getType();
@@ -2744,8 +2658,8 @@ struct CastOpInterface
         castOp.getResult().getType(), layout, memorySpace);
     Value res =
         b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
-    aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
-    map(bvm, castOp.getResult(), res);
+    state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
+    state.mapBuffer(castOp.getResult(), res);
     return success();
   }
 };
@@ -2766,9 +2680,7 @@ struct DimOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto dimOp = cast<tensor::DimOp>(op);
 
     // Take a guard before anything else.
@@ -2776,8 +2688,7 @@ struct DimOpInterface
     b.setInsertionPoint(dimOp);
 
     if (dimOp.source().getType().isa<RankedTensorType>()) {
-      Value v = lookup(bvm, dimOp.source());
-      assert(v && "missing buffer");
+      Value v = state.lookupBuffer(dimOp.source());
       dimOp.result().replaceAllUsesWith(
           b.create<memref::DimOp>(dimOp.getLoc(), v, dimOp.index()));
     }
@@ -2812,9 +2723,7 @@ struct ExtractSliceOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
 
     // Take a guard before anything else.
@@ -2824,18 +2733,16 @@ struct ExtractSliceOpInterface
 
     Location loc = extractSliceOp.getLoc();
     // Bail if source was not bufferized.
-    Value srcMemref = lookup(bvm, extractSliceOp.source());
-    if (!srcMemref)
-      return failure();
+    Value srcMemref = state.lookupBuffer(extractSliceOp.source());
     auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
     auto dstTensorType =
         extractSliceOp.result().getType().cast<RankedTensorType>();
 
     // If not inplaceable, alloc.
     Value alloc;
-    if (!aliasInfo.isInPlace(extractSliceOp->getResult(0)))
+    if (!state.aliasInfo.isInPlace(extractSliceOp->getResult(0)))
       alloc = createNewAllocDeallocPairForShapedValue(
-          b, loc, extractSliceOp.result(), aliasInfo, allocationFn);
+          b, loc, extractSliceOp.result(), state);
 
     // Set insertion point now that potential alloc/dealloc are introduced.
     b.setInsertionPoint(extractSliceOp);
@@ -2851,17 +2758,18 @@ struct ExtractSliceOpInterface
         loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
         extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
     // Insert new alias.
-    aliasInfo.insertNewBufferAlias(subView, srcMemref);
+    state.aliasInfo.insertNewBufferAlias(subView, srcMemref);
 
     /// If not inplaceable, copy.
     if (alloc) {
       // Do not copy if the copied data is never read.
       if (isValueRead(extractSliceOp.result()))
-        allocationFn.memCpyFn(b, extractSliceOp.getLoc(), subView, alloc);
+        state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView,
+                                     alloc);
       subView = alloc;
     }
 
-    map(bvm, extractSliceOp.result(), subView);
+    state.mapBuffer(extractSliceOp.result(), subView);
     return success();
   }
 };
@@ -2882,9 +2790,7 @@ struct ExtractOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto extractOp = cast<tensor::ExtractOp>(op);
 
     // Take a guard before anything else.
@@ -2892,7 +2798,7 @@ struct ExtractOpInterface
     b.setInsertionPoint(extractOp);
 
     Location loc = extractOp.getLoc();
-    Value srcMemref = lookup(bvm, extractOp.tensor());
+    Value srcMemref = state.lookupBuffer(extractOp.tensor());
     Value l = b.create<memref::LoadOp>(loc, srcMemref, extractOp.indices());
     extractOp.replaceAllUsesWith(l);
     return success();
@@ -2950,9 +2856,7 @@ struct InsertSliceOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
 
     // Take a guard before anything else.
@@ -2969,15 +2873,12 @@ struct InsertSliceOpInterface
     // TODO: be very loud about it or even consider failing the pass.
     // Alloc a copy for `insertSliceOp.dest()`, it will become the result
     // buffer.
-    Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), bvm,
-                                      aliasInfo, allocationFn);
+    Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state);
     if (!dstMemref)
       return failure();
     auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
 
-    Value srcMemref = lookup(bvm, insertSliceOp.source());
-    if (!srcMemref)
-      return failure();
+    Value srcMemref = state.lookupBuffer(insertSliceOp.source());
     auto subviewMemRefType =
         memref::SubViewOp::inferRankReducedResultType(
             insertSliceOp.getSourceType().getRank(), dstMemrefType,
@@ -2991,9 +2892,9 @@ struct InsertSliceOpInterface
     //   - The result is not inplace. This is the case where the whole tensor is
     //     cloned and the clone needs to be updated.
     // TODO: Is this necessary?
-    if (!isSourceEquivalentToAMatchingInplaceExtractSliceOp(aliasInfo,
+    if (!isSourceEquivalentToAMatchingInplaceExtractSliceOp(state.aliasInfo,
                                                             insertSliceOp) ||
-        !aliasInfo.isInPlace(insertSliceOp->getResult(0))) {
+        !state.aliasInfo.isInPlace(insertSliceOp->getResult(0))) {
       LDBG("insert_slice needs extra source copy: " << insertSliceOp.source()
                                                     << " -> copy\n");
       // Take a subview of the dst.
@@ -3001,11 +2902,12 @@ struct InsertSliceOpInterface
           loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
           insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
       // Insert new alias.
-      aliasInfo.insertNewBufferAlias(subView, dstMemref);
-      allocationFn.memCpyFn(b, insertSliceOp.getLoc(), srcMemref, subView);
+      state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
+      state.allocationFns.memCpyFn(b, insertSliceOp.getLoc(), srcMemref,
+                                   subView);
     }
 
-    map(bvm, insertSliceOp.result(), dstMemref);
+    state.mapBuffer(insertSliceOp.result(), dstMemref);
 
     return success();
   }
@@ -3035,9 +2937,7 @@ struct TransferReadOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto transferReadOp = cast<vector::TransferReadOp>(op);
 
     // Take a guard before anything else.
@@ -3048,8 +2948,7 @@ struct TransferReadOpInterface
       return failure();
 
     // TransferReadOp always reads from the bufferized op.source().
-    Value v = lookup(bvm, transferReadOp.source());
-    assert(v && "missing buffer");
+    Value v = state.lookupBuffer(transferReadOp.source());
     transferReadOp.sourceMutable().assign(v);
     return success();
   }
@@ -3086,9 +2985,7 @@ struct TransferWriteOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto writeOp = cast<vector::TransferWriteOp>(op);
 
     // Take a guard before anything else.
@@ -3101,15 +2998,14 @@ struct TransferWriteOpInterface
     // Create a new transfer_write on buffer that doesn't have a return value.
     // Leave the previous transfer_write to dead code as it still has uses at
     // this point.
-    Value resultBuffer =
-        getResultBuffer(b, op->getResult(0), bvm, aliasInfo, allocationFn);
+    Value resultBuffer = getResultBuffer(b, op->getResult(0), state);
     if (!resultBuffer)
       return failure();
     b.create<vector::TransferWriteOp>(
         writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
         writeOp.permutation_map(),
         writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr());
-    map(bvm, op->getResult(0), resultBuffer);
+    state.mapBuffer(op->getResult(0), resultBuffer);
 
     return success();
   }


        


More information about the Mlir-commits mailing list